diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 4540caf5e7..0745497ddf 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -29,14 +29,12 @@ import bigframes import google.cloud.bigquery import pandas import pyarrow -import sqlglot print(f"Python: {sys.version}") print(f"bigframes=={bigframes.__version__}") print(f"google-cloud-bigquery=={google.cloud.bigquery.__version__}") print(f"pandas=={pandas.__version__}") print(f"pyarrow=={pyarrow.__version__}") -print(f"sqlglot=={sqlglot.__version__}") ``` #### Steps to reproduce diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index b86ae196f6..d2ecfaade0 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core import agg_expressions, window_spec from bigframes.core.compile.sqlglot.aggregations import ( diff --git a/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py index 856b5e2f3a..51ff1ceecc 100644 --- a/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py @@ -16,7 +16,7 @@ import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg diff --git a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py index a582a9d4c5..58ab7ec513 100644 --- a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py @@ -16,7 +16,7 @@ import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index a26429f27e..2b3ba20ef0 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -16,7 +16,7 @@ import typing -from sqlglot import expressions as sge +from bigframes_vendored.sqlglot import expressions as sge from bigframes.operations import aggregations as agg_ops diff --git a/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py index 594d75fd3c..5feaf794e0 100644 --- a/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge import bigframes.core.compile.sqlglot.aggregations.op_registration as reg import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py 
b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index ec711c7fa1..89bb58d7dd 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -16,8 +16,8 @@ import typing +import bigframes_vendored.sqlglot.expressions as sge import pandas as pd -import sqlglot.expressions as sge from bigframes import dtypes from bigframes.core import window_spec diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index d1a68b2ef7..678bb11fbe 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -15,7 +15,7 @@ import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core import utils, window_spec import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 870e7064b8..b3b813a1c0 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -17,7 +17,7 @@ import functools import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core import ( agg_expressions, diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index a8a36cb6c0..748f15b867 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -16,7 +16,7 @@ from dataclasses import asdict -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot import scalar_compiler diff --git a/bigframes/core/compile/sqlglot/expressions/array_ops.py b/bigframes/core/compile/sqlglot/expressions/array_ops.py index 28b3693caf..e83a6ea99a 100644 --- a/bigframes/core/compile/sqlglot/expressions/array_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/array_ops.py @@ -16,8 +16,8 @@ import typing -import sqlglot as sg -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.string_ops import ( diff --git a/bigframes/core/compile/sqlglot/expressions/blob_ops.py b/bigframes/core/compile/sqlglot/expressions/blob_ops.py index 03708f80c6..03b92c14fe 100644 --- a/bigframes/core/compile/sqlglot/expressions/blob_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/blob_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr diff --git a/bigframes/core/compile/sqlglot/expressions/bool_ops.py b/bigframes/core/compile/sqlglot/expressions/bool_ops.py index 41076b666a..26653d720c 100644 --- a/bigframes/core/compile/sqlglot/expressions/bool_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/bool_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git 
a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 89d3b4a682..d64c7b1d3f 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -16,8 +16,8 @@ import typing +import bigframes_vendored.sqlglot.expressions as sge import pandas as pd -import sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/constants.py b/bigframes/core/compile/sqlglot/expressions/constants.py index e005a1ed78..f383306292 100644 --- a/bigframes/core/compile/sqlglot/expressions/constants.py +++ b/bigframes/core/compile/sqlglot/expressions/constants.py @@ -14,7 +14,7 @@ import math -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge _ZERO = sge.Cast(this=sge.convert(0), to="INT64") _NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64") diff --git a/bigframes/core/compile/sqlglot/expressions/date_ops.py b/bigframes/core/compile/sqlglot/expressions/date_ops.py index be772d978d..3de7c4b23b 100644 --- a/bigframes/core/compile/sqlglot/expressions/date_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/date_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr diff --git a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py index 78e17ae33b..e20d2da567 100644 --- a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index e44a1b5c1d..27973ef8b5 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -14,8 +14,8 @@ from __future__ import annotations -import sqlglot as sg -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/geo_ops.py b/bigframes/core/compile/sqlglot/expressions/geo_ops.py index 5716dba0e4..a57b4bc931 100644 --- a/bigframes/core/compile/sqlglot/expressions/geo_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/geo_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr diff --git a/bigframes/core/compile/sqlglot/expressions/json_ops.py b/bigframes/core/compile/sqlglot/expressions/json_ops.py index 0a38e8e138..d2008b45bf 100644 --- a/bigframes/core/compile/sqlglot/expressions/json_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/json_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import 
bigframes_vendored.sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index f7da28c5d2..16f7dec717 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -15,7 +15,7 @@ from __future__ import annotations import bigframes_vendored.constants as bf_constants -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/string_ops.py b/bigframes/core/compile/sqlglot/expressions/string_ops.py index 6af9b6a526..abe0145a16 100644 --- a/bigframes/core/compile/sqlglot/expressions/string_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/string_ops.py @@ -17,7 +17,7 @@ import functools import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/struct_ops.py b/bigframes/core/compile/sqlglot/expressions/struct_ops.py index b6ec101eb1..5048941f14 100644 --- a/bigframes/core/compile/sqlglot/expressions/struct_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/struct_ops.py @@ -16,9 +16,9 @@ import typing +import bigframes_vendored.sqlglot.expressions as sge import pandas as pd import pyarrow as pa -import sqlglot.expressions as sge from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr diff --git a/bigframes/core/compile/sqlglot/expressions/timedelta_ops.py b/bigframes/core/compile/sqlglot/expressions/timedelta_ops.py index f5b9f891c1..b442fa8175 100644 --- a/bigframes/core/compile/sqlglot/expressions/timedelta_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/timedelta_ops.py @@ -14,7 +14,7 @@ from __future__ import annotations -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops diff --git a/bigframes/core/compile/sqlglot/expressions/typed_expr.py b/bigframes/core/compile/sqlglot/expressions/typed_expr.py index e693dd94a2..4623b8c9b4 100644 --- a/bigframes/core/compile/sqlglot/expressions/typed_expr.py +++ b/bigframes/core/compile/sqlglot/expressions/typed_expr.py @@ -14,7 +14,7 @@ import dataclasses -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes import dtypes diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index 1da58871c7..317141b6cc 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -16,7 +16,7 @@ import functools import typing -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot.expressions as sge from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.sqlglot_ir as ir diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 176564fe23..04176014b0 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -19,13 +19,12 @@ import 
functools import typing +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge from google.cloud import bigquery import numpy as np import pandas as pd import pyarrow as pa -import sqlglot as sg -import sqlglot.dialects.bigquery -import sqlglot.expressions as sge from bigframes import dtypes from bigframes.core import guid, local_data, schema, utils @@ -48,7 +47,7 @@ class SQLGlotIR: expr: sge.Select = sg.select() """The SQLGlot expression representing the query.""" - dialect = sqlglot.dialects.bigquery.BigQuery + dialect = sg.dialects.bigquery.BigQuery """The SQL dialect used for generation.""" quoted: bool = True diff --git a/bigframes/core/compile/sqlglot/sqlglot_types.py b/bigframes/core/compile/sqlglot/sqlglot_types.py index 64e4363ddf..d22373b303 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_types.py +++ b/bigframes/core/compile/sqlglot/sqlglot_types.py @@ -17,10 +17,10 @@ import typing import bigframes_vendored.constants as constants +import bigframes_vendored.sqlglot as sg import numpy as np import pandas as pd import pyarrow as pa -import sqlglot as sg import bigframes.dtypes diff --git a/noxfile.py b/noxfile.py index 44fc5adede..9abe10f58f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -437,6 +437,8 @@ def doctest(session: nox.sessions.Session): "--ignore", "third_party/bigframes_vendored/ibis", "--ignore", + "third_party/bigframes_vendored/sqlglot", + "--ignore", "bigframes/core/compile/polars", "--ignore", "bigframes/testing", diff --git a/setup.py b/setup.py index fa663f66d5..720687952c 100644 --- a/setup.py +++ b/setup.py @@ -54,8 +54,6 @@ "pydata-google-auth >=1.8.2", "requests >=2.27.1", "shapely >=1.8.5", - # 25.20.0 introduces this fix https://github.com/TobikoData/sqlmesh/issues/3095 for rtrim/ltrim. - "sqlglot >=25.20.0", "tabulate >=0.9", "ipywidgets >=7.7.1", "humanize >=4.6.0", diff --git a/testing/constraints-3.11.txt b/testing/constraints-3.11.txt index 8c274bd9fb..831d22b0ff 100644 --- a/testing/constraints-3.11.txt +++ b/testing/constraints-3.11.txt @@ -520,7 +520,6 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==2.0.0 sphinxcontrib-serializinghtml==2.0.0 SQLAlchemy==2.0.42 -sqlglot==25.20.2 sqlparse==0.5.3 srsly==2.5.1 stanio==0.5.1 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index b8dc8697d6..9865d3b364 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -21,7 +21,6 @@ pydata-google-auth==1.8.2 requests==2.27.1 scikit-learn==1.2.2 shapely==1.8.5 -sqlglot==25.20.0 tabulate==0.9 ipywidgets==7.7.1 humanize==4.6.0 diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index e5af45ec2b..b4dc3d2508 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -14,11 +14,9 @@ from unittest import mock -from packaging import version import pandas as pd import pyarrow as pa import pytest -import sqlglot from bigframes import dataframe, dtypes, series import bigframes.bigquery as bbq @@ -67,11 +65,6 @@ def test_ai_function_string_input(session): def test_ai_function_compile_model_params(session): - if version.Version(sqlglot.__version__) < version.Version("25.18.0"): - pytest.skip( - "Skip test because SQLGLot cannot compile model params to JSON at this version." 
- ) - s1 = bpd.Series(["apple", "bear"], session=session) s2 = bpd.Series(["fruit", "tree"], session=session) prompt = (s1, " is a ", s2) diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py index dbdeb2307e..c6c1c21151 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from bigframes_vendored.sqlglot import expressions as sge import pytest -from sqlglot import expressions as sge from bigframes.core.compile.sqlglot.aggregations import op_registration from bigframes.operations import aggregations as agg_ops diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py index 2f88fb5d0c..d3a36866f0 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import typing import pytest @@ -47,12 +46,6 @@ def _apply_ordered_unary_agg_ops( def test_array_agg(scalar_types_df: bpd.DataFrame, snapshot): - # TODO: Verify "NULL LAST" syntax issue on Python < 3.12 - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) - col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_ops.ArrayAggOp().as_expr(col_name) @@ -64,12 +57,6 @@ def test_array_agg(scalar_types_df: bpd.DataFrame, snapshot): def test_string_agg(scalar_types_df: bpd.DataFrame, snapshot): - # TODO: Verify "NULL LAST" syntax issue on Python < 3.12 - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) - col_name = "string_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_ops.StringAggOp(sep=",").as_expr(col_name) diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index fbf631d1a0..c15d70478a 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
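The Python-version skips deleted above existed because the installed sqlglot could format SQL differently across environments. With sqlglot vendored and pinned, generation is deterministic, so the snapshot tests can run on every supported interpreter. A minimal sketch of the drop-in import swap the rest of this patch performs (the table and column names are placeholders, not taken from this change):

```python
import bigframes_vendored.sqlglot as sg
import bigframes_vendored.sqlglot.expressions as sge

# Same expression-building API as upstream sqlglot; only the import path moved.
expr = sg.select(sge.func("SUM", sge.column("int64_col", quoted=True))).from_(
    sg.table("my_table", db="my_dataset")  # placeholder names
)
# Renders: SELECT SUM(`int64_col`) FROM my_dataset.my_table
print(expr.sql(dialect="bigquery"))
```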
-import sys import typing import pytest @@ -260,10 +259,6 @@ def test_diff_w_timestamp(scalar_types_df: bpd.DataFrame, snapshot): def test_first(scalar_types_df: bpd.DataFrame, snapshot): - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name)) @@ -274,10 +269,6 @@ def test_first(scalar_types_df: bpd.DataFrame, snapshot): def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation( @@ -290,10 +281,6 @@ def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): def test_last(scalar_types_df: bpd.DataFrame, snapshot): - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name)) @@ -304,10 +291,6 @@ def test_last(scalar_types_df: bpd.DataFrame, snapshot): def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot): - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation( @@ -475,11 +458,6 @@ def test_product(scalar_types_df: bpd.DataFrame, snapshot): def test_qcut(scalar_types_df: bpd.DataFrame, snapshot): - if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - ) - col_name = "int64_col" bf = scalar_types_df[[col_name]] bf["qcut_w_int"] = bpd.qcut(bf[col_name], q=4, labels=False, duplicates="drop") diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_windows.py b/tests/unit/core/compile/sqlglot/aggregations/test_windows.py index af347f4aa3..9d0cdd20a0 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_windows.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_windows.py @@ -14,9 +14,9 @@ import unittest +import bigframes_vendored.sqlglot.expressions as sge import pandas as pd import pytest -import sqlglot.expressions as sge from bigframes import dtypes from bigframes.core import window_spec diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 1397c7d6c0..c0cbece905 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -14,9 +14,7 @@ import json -from packaging import version import pytest -import sqlglot from bigframes import dataframe from bigframes import operations as ops @@ -85,11 +83,6 @@ def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, sn def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot): - if version.Version(sqlglot.__version__) < version.Version("25.18.0"): - pytest.skip( - "Skip test because SQLGLot cannot compile model params to JSON at this version." 
- ) - col_name = "string_col" op = ops.AIGenerate( @@ -149,11 +142,6 @@ def test_ai_generate_bool_with_connection_id( def test_ai_generate_bool_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): - if version.Version(sqlglot.__version__) < version.Version("25.18.0"): - pytest.skip( - "Skip test because SQLGLot cannot compile model params to JSON at this version." - ) - col_name = "string_col" op = ops.AIGenerateBool( @@ -214,11 +202,6 @@ def test_ai_generate_int_with_connection_id( def test_ai_generate_int_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): - if version.Version(sqlglot.__version__) < version.Version("25.18.0"): - pytest.skip( - "Skip test because SQLGLot cannot compile model params to JSON at this version." - ) - col_name = "string_col" op = ops.AIGenerateInt( @@ -280,11 +263,6 @@ def test_ai_generate_double_with_connection_id( def test_ai_generate_double_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): - if version.Version(sqlglot.__version__) < version.Version("25.18.0"): - pytest.skip( - "Skip test because SQLGLot cannot compile model params to JSON at this version." - ) - col_name = "string_col" op = ops.AIGenerateDouble( diff --git a/tests/unit/core/compile/sqlglot/test_compile_isin.py b/tests/unit/core/compile/sqlglot/test_compile_isin.py index 94a533abe6..8b3e7f7291 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_isin.py +++ b/tests/unit/core/compile/sqlglot/test_compile_isin.py @@ -12,20 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys - import pytest import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") -if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - allow_module_level=True, - ) - def test_compile_isin(scalar_types_df: bpd.DataFrame, snapshot): bf_isin = scalar_types_df["int64_col"].isin(scalar_types_df["int64_too"]).to_frame() diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py index 1fc70dc30f..1602ec2c47 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_window.py +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
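The sqlglot-version probes removed here follow the same logic: the vendored snapshot is fixed at a release that can compile model params to JSON, so a runtime capability check is redundant. If such a guard were ever reinstated it would reduce to a constant, as in this sketch (which assumes the vendored `_version` module is generated, as the package `__init__` expects):

```python
from packaging import version

import bigframes_vendored.sqlglot as sg

# With a pinned vendored copy this comparison is a fixed fact of the tree,
# not an environment-dependent condition.
assert version.Version(sg.__version__) >= version.Version("25.18.0")
```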
-import sys - import numpy as np import pandas as pd import pytest @@ -23,13 +21,6 @@ pytest.importorskip("pytest_snapshot") -if sys.version_info < (3, 12): - pytest.skip( - "Skipping test due to inconsistent SQL formatting on Python < 3.12.", - allow_module_level=True, - ) - - def test_compile_window_w_skips_nulls_op(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col"]].sort_index() # The SumOp's skips_nulls is True diff --git a/tests/unit/core/compile/sqlglot/test_scalar_compiler.py b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py index 14d7b47389..3469d15d74 100644 --- a/tests/unit/core/compile/sqlglot/test_scalar_compiler.py +++ b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py @@ -14,8 +14,8 @@ import unittest.mock as mock +import bigframes_vendored.sqlglot.expressions as sge import pytest -import sqlglot.expressions as sge from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler diff --git a/third_party/bigframes_vendored/ibis/backends/__init__.py b/third_party/bigframes_vendored/ibis/backends/__init__.py index 86a6423d48..23e3f03f4d 100644 --- a/third_party/bigframes_vendored/ibis/backends/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/__init__.py @@ -24,10 +24,10 @@ from collections.abc import Iterable, Iterator, Mapping, MutableMapping from urllib.parse import ParseResult + import bigframes_vendored.sqlglot as sg import pandas as pd import polars as pl import pyarrow as pa - import sqlglot as sg import torch __all__ = ("BaseBackend", "connect") @@ -1257,7 +1257,7 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str: if dialect is None: return query - import sqlglot as sg + import bigframes_vendored.sqlglot as sg # only transpile if the backend dialect doesn't match the input dialect name = self.name diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py index a87cb081cb..b342c7e4a9 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py @@ -32,14 +32,14 @@ import bigframes_vendored.ibis.expr.operations as ops import bigframes_vendored.ibis.expr.schema as sch import bigframes_vendored.ibis.expr.types as ir +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge import google.api_core.exceptions import google.auth.credentials import google.cloud.bigquery as bq import google.cloud.bigquery_storage_v1 as bqstorage import pydata_google_auth from pydata_google_auth import cache -import sqlglot as sg -import sqlglot.expressions as sge if TYPE_CHECKING: from collections.abc import Iterable, Mapping diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py index 3d214766dc..bac508dc7a 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -30,14 +30,14 @@ import bigframes_vendored.ibis.expr.operations as ops import bigframes_vendored.ibis.expr.schema as sch import bigframes_vendored.ibis.expr.types as ir +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge import google.api_core.exceptions import google.auth.credentials import google.cloud.bigquery as bq import 
google.cloud.bigquery_storage_v1 as bqstorage import pydata_google_auth from pydata_google_auth import cache -import sqlglot as sg -import sqlglot.expressions as sge if TYPE_CHECKING: from collections.abc import Iterable, Mapping diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/datatypes.py b/third_party/bigframes_vendored/ibis/backends/bigquery/datatypes.py index fba0339ae9..6039ecdf1b 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/datatypes.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/datatypes.py @@ -6,8 +6,8 @@ import bigframes_vendored.ibis.expr.datatypes as dt import bigframes_vendored.ibis.expr.schema as sch from bigframes_vendored.ibis.formats import SchemaMapper, TypeMapper +import bigframes_vendored.sqlglot as sg import google.cloud.bigquery as bq -import sqlglot as sg _from_bigquery_types = { "INT64": dt.Int64, diff --git a/third_party/bigframes_vendored/ibis/backends/sql/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/__init__.py index 8598e1af72..0e7b31527a 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/__init__.py @@ -14,8 +14,8 @@ import bigframes_vendored.ibis.expr.operations as ops import bigframes_vendored.ibis.expr.schema as sch import bigframes_vendored.ibis.expr.types as ir -import sqlglot as sg -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge if TYPE_CHECKING: from collections.abc import Iterable, Mapping diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index c01d87fb28..b95e428053 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -30,14 +30,14 @@ import bigframes_vendored.ibis.expr.operations as ops from bigframes_vendored.ibis.expr.operations.udf import InputType from bigframes_vendored.ibis.expr.rewrites import lower_stringslice +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge from public import public -import sqlglot as sg -import sqlglot.expressions as sge try: - from sqlglot.expressions import Alter + from bigframes_vendored.sqlglot.expressions import Alter except ImportError: - from sqlglot.expressions import AlterTable + from bigframes_vendored.sqlglot.expressions import AlterTable else: def AlterTable(*args, kind="TABLE", **kwargs): diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 95d28991a9..1fa5432a16 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -32,10 +32,10 @@ ) import bigframes_vendored.ibis.expr.datatypes as dt import bigframes_vendored.ibis.expr.operations as ops +import bigframes_vendored.sqlglot as sg +from bigframes_vendored.sqlglot.dialects import BigQuery +import bigframes_vendored.sqlglot.expressions as sge import numpy as np -import sqlglot as sg -from sqlglot.dialects import BigQuery -import sqlglot.expressions as sge if TYPE_CHECKING: from collections.abc import Mapping diff --git a/third_party/bigframes_vendored/ibis/backends/sql/datatypes.py 
b/third_party/bigframes_vendored/ibis/backends/sql/datatypes.py index fce0643783..169871000a 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/datatypes.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/datatypes.py @@ -8,8 +8,8 @@ import bigframes_vendored.ibis.common.exceptions as com import bigframes_vendored.ibis.expr.datatypes as dt from bigframes_vendored.ibis.formats import TypeMapper -import sqlglot as sg -import sqlglot.expressions as sge +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge typecode = sge.DataType.Type diff --git a/third_party/bigframes_vendored/ibis/expr/sql.py b/third_party/bigframes_vendored/ibis/expr/sql.py index 45d9ab6f2f..0d6df4684a 100644 --- a/third_party/bigframes_vendored/ibis/expr/sql.py +++ b/third_party/bigframes_vendored/ibis/expr/sql.py @@ -13,11 +13,11 @@ import bigframes_vendored.ibis.expr.types as ibis_types import bigframes_vendored.ibis.expr.types as ir from bigframes_vendored.ibis.util import experimental +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge +import bigframes_vendored.sqlglot.optimizer as sgo +import bigframes_vendored.sqlglot.planner as sgp from public import public -import sqlglot as sg -import sqlglot.expressions as sge -import sqlglot.optimizer as sgo -import sqlglot.planner as sgp class Catalog(dict[str, sch.Schema]): diff --git a/third_party/bigframes_vendored/sqlglot/__init__.py b/third_party/bigframes_vendored/sqlglot/__init__.py new file mode 100644 index 0000000000..c596720fd9 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/__init__.py @@ -0,0 +1,200 @@ +# ruff: noqa: F401 +""" +.. include:: ../README.md + +---- +""" + +from __future__ import annotations + +import logging +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect as Dialect # noqa: F401 +from bigframes_vendored.sqlglot.dialects.dialect import ( # noqa: F401 + Dialects as Dialects, +) +from bigframes_vendored.sqlglot.diff import diff as diff # noqa: F401 +from bigframes_vendored.sqlglot.errors import ErrorLevel as ErrorLevel +from bigframes_vendored.sqlglot.errors import ParseError as ParseError +from bigframes_vendored.sqlglot.errors import TokenError as TokenError # noqa: F401 +from bigframes_vendored.sqlglot.errors import ( # noqa: F401 + UnsupportedError as UnsupportedError, +) +from bigframes_vendored.sqlglot.expressions import alias_ as alias # noqa: F401 +from bigframes_vendored.sqlglot.expressions import and_ as and_ # noqa: F401 +from bigframes_vendored.sqlglot.expressions import case as case # noqa: F401 +from bigframes_vendored.sqlglot.expressions import cast as cast # noqa: F401 +from bigframes_vendored.sqlglot.expressions import column as column # noqa: F401 +from bigframes_vendored.sqlglot.expressions import condition as condition # noqa: F401 +from bigframes_vendored.sqlglot.expressions import delete as delete # noqa: F401 +from bigframes_vendored.sqlglot.expressions import except_ as except_ # noqa: F401 +from bigframes_vendored.sqlglot.expressions import ( # noqa: F401 + Expression as Expression, +) +from bigframes_vendored.sqlglot.expressions import ( # noqa: F401 + find_tables as find_tables, +) +from bigframes_vendored.sqlglot.expressions import from_ as from_ # noqa: F401 +from bigframes_vendored.sqlglot.expressions import func as func # noqa: F401 +from bigframes_vendored.sqlglot.expressions import insert as insert # noqa: F401 
+from bigframes_vendored.sqlglot.expressions import intersect as intersect # noqa: F401 +from bigframes_vendored.sqlglot.expressions import ( # noqa: F401 + maybe_parse as maybe_parse, +) +from bigframes_vendored.sqlglot.expressions import merge as merge # noqa: F401 +from bigframes_vendored.sqlglot.expressions import not_ as not_ # noqa: F401 +from bigframes_vendored.sqlglot.expressions import or_ as or_ # noqa: F401 +from bigframes_vendored.sqlglot.expressions import select as select # noqa: F401 +from bigframes_vendored.sqlglot.expressions import subquery as subquery # noqa: F401 +from bigframes_vendored.sqlglot.expressions import table_ as table # noqa: F401 +from bigframes_vendored.sqlglot.expressions import to_column as to_column # noqa: F401 +from bigframes_vendored.sqlglot.expressions import ( # noqa: F401 + to_identifier as to_identifier, +) +from bigframes_vendored.sqlglot.expressions import to_table as to_table # noqa: F401 +from bigframes_vendored.sqlglot.expressions import union as union # noqa: F401 +from bigframes_vendored.sqlglot.generator import Generator as Generator # noqa: F401 +from bigframes_vendored.sqlglot.parser import Parser as Parser # noqa: F401 +from bigframes_vendored.sqlglot.schema import ( # noqa: F401 + MappingSchema as MappingSchema, +) +from bigframes_vendored.sqlglot.schema import Schema as Schema # noqa: F401 +from bigframes_vendored.sqlglot.tokens import Token as Token # noqa: F401 +from bigframes_vendored.sqlglot.tokens import Tokenizer as Tokenizer # noqa: F401 +from bigframes_vendored.sqlglot.tokens import TokenType as TokenType # noqa: F401 + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.dialects.dialect import DialectType as DialectType + +logger = logging.getLogger("sqlglot") + + +try: + from bigframes_vendored.sqlglot._version import ( # noqa: F401 + __version__, + __version_tuple__, + ) +except ImportError: + logger.error( + "Unable to set __version__, run `pip install -e .` or `python setup.py develop` first." + ) + + +pretty = False +"""Whether to format generated SQL by default.""" + + +def tokenize( + sql: str, read: DialectType = None, dialect: DialectType = None +) -> t.List[Token]: + """ + Tokenizes the given SQL string. + + Args: + sql: the SQL code string to tokenize. + read: the SQL dialect to apply during tokenizing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read). + + Returns: + The resulting list of tokens. + """ + return Dialect.get_or_raise(read or dialect).tokenize(sql) + + +def parse( + sql: str, read: DialectType = None, dialect: DialectType = None, **opts +) -> t.List[t.Optional[Expression]]: + """ + Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. + + Args: + sql: the SQL code string to parse. + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read). + **opts: other `sqlglot.parser.Parser` options. + + Returns: + The resulting syntax tree collection. + """ + return Dialect.get_or_raise(read or dialect).parse(sql, **opts) + + +@t.overload +def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: + ... + + +@t.overload +def parse_one(sql: str, **opts) -> Expression: + ... 
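+
+# The overloads above only refine static typing: passing `into=` narrows the
+# return type for type checkers. The runtime behavior lives in the
+# implementation that follows.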
+ + +def parse_one( + sql: str, + read: DialectType = None, + dialect: DialectType = None, + into: t.Optional[exp.IntoType] = None, + **opts, +) -> Expression: + """ + Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. + + Args: + sql: the SQL code string to parse. + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read) + into: the SQLGlot Expression to parse into. + **opts: other `sqlglot.parser.Parser` options. + + Returns: + The syntax tree for the first parsed statement. + """ + + dialect = Dialect.get_or_raise(read or dialect) + + if into: + result = dialect.parse_into(into, sql, **opts) + else: + result = dialect.parse(sql, **opts) + + for expression in result: + if not expression: + raise ParseError(f"No expression was parsed from '{sql}'") + return expression + else: + raise ParseError(f"No expression was parsed from '{sql}'") + + +def transpile( + sql: str, + read: DialectType = None, + write: DialectType = None, + identity: bool = True, + error_level: t.Optional[ErrorLevel] = None, + **opts, +) -> t.List[str]: + """ + Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed + to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement. + + Args: + sql: the SQL code string to transpile. + read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql"). + write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql"). + identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both: + the source and the target dialect. + error_level: the desired error level of the parser. + **opts: other `sqlglot.generator.Generator` options. + + Returns: + The list of transpiled SQL statements. 
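+
+    Example:
+        >>> transpile("select 1", read="bigquery", write="bigquery")
+        ['SELECT 1']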
+ """ + write = (read if write is None else write) if identity else write + write = Dialect.get_or_raise(write) + return [ + write.generate(expression, copy=False, **opts) if expression else "" + for expression in parse(sql, read, error_level=error_level) + ] diff --git a/third_party/bigframes_vendored/sqlglot/__main__.py b/third_party/bigframes_vendored/sqlglot/__main__.py new file mode 100644 index 0000000000..5b979d5484 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/__main__.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import argparse +import sys +import typing as t + +import bigframes_vendored.sqlglot +from bigframes_vendored.sqlglot.helper import to_bool + +parser = argparse.ArgumentParser(description="Transpile SQL") +parser.add_argument( + "sql", + metavar="sql", + type=str, + help="SQL statement(s) to transpile, or - to parse stdin.", +) +parser.add_argument( + "--read", + dest="read", + type=str, + default=None, + help="Dialect to read default is generic", +) +parser.add_argument( + "--write", + dest="write", + type=str, + default=None, + help="Dialect to write default is generic", +) +parser.add_argument( + "--identify", + dest="identify", + type=str, + default="safe", + help="Whether to quote identifiers (safe, true, false)", +) +parser.add_argument( + "--no-pretty", + dest="pretty", + action="store_false", + help="Compress sql", +) +parser.add_argument( + "--parse", + dest="parse", + action="store_true", + help="Parse and return the expression tree", +) +parser.add_argument( + "--tokenize", + dest="tokenize", + action="store_true", + help="Tokenize and return the tokens list", +) +parser.add_argument( + "--error-level", + dest="error_level", + type=str, + default="IMMEDIATE", + help="IGNORE, WARN, RAISE, IMMEDIATE (default)", +) +parser.add_argument( + "--version", + action="version", + version=bigframes_vendored.sqlglot.__version__, + help="Display the SQLGlot version", +) + + +args = parser.parse_args() +error_level = bigframes_vendored.sqlglot.ErrorLevel[args.error_level.upper()] + +sql = sys.stdin.read() if args.sql == "-" else args.sql + +if args.parse: + objs: t.Union[t.List[str], t.List[bigframes_vendored.sqlglot.tokens.Token]] = [ + repr(expression) + for expression in bigframes_vendored.sqlglot.parse( + sql, + read=args.read, + error_level=error_level, + ) + ] +elif args.tokenize: + objs = bigframes_vendored.sqlglot.Dialect.get_or_raise(args.read).tokenize(sql) +else: + objs = bigframes_vendored.sqlglot.transpile( + sql, + read=args.read, + write=args.write, + identify="safe" if args.identify == "safe" else to_bool(args.identify), + pretty=args.pretty, + error_level=error_level, + ) + +for obj in objs: + print(obj) diff --git a/third_party/bigframes_vendored/sqlglot/_typing.py b/third_party/bigframes_vendored/sqlglot/_typing.py new file mode 100644 index 0000000000..d374940739 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/_typing.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import typing as t + +import bigframes_vendored.sqlglot + +if t.TYPE_CHECKING: + from typing_extensions import Literal as Lit # noqa + +# A little hack for backwards compatibility with Python 3.7. +# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed. +# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency. 
+A = t.TypeVar("A", bound=t.Any) +B = t.TypeVar("B", bound="bigframes_vendored.sqlglot.exp.Binary") +E = t.TypeVar("E", bound="bigframes_vendored.sqlglot.exp.Expression") +F = t.TypeVar("F", bound="bigframes_vendored.sqlglot.exp.Func") +T = t.TypeVar("T") diff --git a/third_party/bigframes_vendored/sqlglot/dialects/__init__.py b/third_party/bigframes_vendored/sqlglot/dialects/__init__.py new file mode 100644 index 0000000000..cc835959ae --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/dialects/__init__.py @@ -0,0 +1,97 @@ +# ruff: noqa: F401 +""" +## Dialects + +While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult +to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible +SQL transpilation framework. + +The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible. + +Each SQL variation has its own `Dialect` subclass, extending the corresponding `Tokenizer`, `Parser` and `Generator` +classes as needed. + +### Implementing a custom Dialect + +Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot: + +```python +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.tokens import Tokenizer, TokenType + + +class Custom(Dialect): + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] # Strings can be delimited by either single or double quotes + IDENTIFIERS = ["`"] # Identifiers can be delimited by backticks + + # Associates certain meaningful words with tokens that capture their intent + KEYWORDS = { + **Tokenizer.KEYWORDS, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + } + + class Generator(Generator): + # Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL + TRANSFORMS = { + exp.Array: lambda self, e: f"[{self.expressions(e)}]", + } + + # Specifies how AST nodes representing data types should be converted into SQL + TYPE_MAPPING = { + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + } +``` + +The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different +specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing +dialect implementations in order to understand how their various components can be modified, depending on the use-case. + +---- +""" + +import importlib +import threading + +DIALECTS = [ + "BigQuery", +] + +MODULE_BY_DIALECT = {name: name.lower() for name in DIALECTS} +DIALECT_MODULE_NAMES = MODULE_BY_DIALECT.values() + +MODULE_BY_ATTRIBUTE = { + **MODULE_BY_DIALECT, + "Dialect": "dialect", + "Dialects": "dialect", +} + +__all__ = list(MODULE_BY_ATTRIBUTE) + +# We use a reentrant lock because a dialect may depend on (i.e., import) other dialects. +# Without it, the first dialect import would never be completed, because subsequent +# imports would be blocked on the lock held by the first import. 
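+# Dialect modules are loaded lazily through the module-level __getattr__ below
+# (PEP 562): a dialect is imported only on first attribute access, and the
+# lock serializes those imports.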
+_import_lock = threading.RLock() + + +def __getattr__(name): + module_name = MODULE_BY_ATTRIBUTE.get(name) + if module_name: + with _import_lock: + module = importlib.import_module( + f"bigframes_vendored.sqlglot.dialects.{module_name}" + ) + return getattr(module, name) + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/third_party/bigframes_vendored/sqlglot/dialects/bigquery.py b/third_party/bigframes_vendored/sqlglot/dialects/bigquery.py new file mode 100644 index 0000000000..6fd0021e02 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/dialects/bigquery.py @@ -0,0 +1,1680 @@ +from __future__ import annotations + +import logging +import re +import typing as t + +from bigframes_vendored.sqlglot import ( + exp, + generator, + jsonpath, + parser, + tokens, + transforms, +) +from bigframes_vendored.sqlglot._typing import E +from bigframes_vendored.sqlglot.dialects.dialect import ( + arg_max_or_min_no_count, + binary_from_function, + build_date_delta_with_interval, + build_formatted_time, + date_add_interval_sql, + datestrtodate_sql, + Dialect, + filter_array_using_unnest, + groupconcat_sql, + if_sql, + inline_array_unless_query, + max_or_greatest, + min_or_least, + no_ilike_sql, + NormalizationStrategy, + regexp_replace_sql, + rename_func, + sha2_digest_sql, + sha256_sql, + strposition_sql, + timestrtotime_sql, + ts_or_ds_add_cast, + unit_to_var, +) +from bigframes_vendored.sqlglot.generator import unsupported_args +from bigframes_vendored.sqlglot.helper import seq_get, split_num_words +from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator +from bigframes_vendored.sqlglot.tokens import TokenType +from bigframes_vendored.sqlglot.typing.bigquery import EXPRESSION_METADATA + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import Lit + +logger = logging.getLogger("sqlglot") + + +JSON_EXTRACT_TYPE = t.Union[ + exp.JSONExtract, exp.JSONExtractScalar, exp.JSONExtractArray +] + +DQUOTES_ESCAPING_JSON_FUNCTIONS = ("JSON_QUERY", "JSON_VALUE", "JSON_QUERY_ARRAY") + +MAKE_INTERVAL_KWARGS = ["year", "month", "day", "hour", "minute", "second"] + + +def _derived_table_values_to_unnest( + self: BigQuery.Generator, expression: exp.Values +) -> str: + if not expression.find_ancestor(exp.From, exp.Join): + return self.values_sql(expression) + + structs = [] + alias = expression.args.get("alias") + for tup in expression.find_all(exp.Tuple): + field_aliases = ( + alias.columns + if alias and alias.columns + else (f"_c{i}" for i in range(len(tup.expressions))) + ) + expressions = [ + exp.PropertyEQ(this=exp.to_identifier(name), expression=fld) + for name, fld in zip(field_aliases, tup.expressions) + ] + structs.append(exp.Struct(expressions=expressions)) + + # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression + alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None + return self.unnest_sql( + exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only) + ) + + +def _returnsproperty_sql( + self: BigQuery.Generator, expression: exp.ReturnsProperty +) -> str: + this = expression.this + if isinstance(this, exp.Schema): + this = f"{self.sql(this, 'this')} <{self.expressions(this)}>" + else: + this = self.sql(this) + return f"RETURNS {this}" + + +def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: + returns = expression.find(exp.ReturnsProperty) + if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"): 
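+        # BigQuery spells table-valued function DDL as CREATE TABLE FUNCTION,
+        # so rewrite the kind before generating.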
+ expression.set("kind", "TABLE FUNCTION") + + if isinstance(expression.expression, (exp.Subquery, exp.Literal)): + expression.set("expression", expression.expression.this) + + return self.create_sql(expression) + + +# https://issuetracker.google.com/issues/162294746 +# workaround for bigquery bug when grouping by an expression and then ordering +# WITH x AS (SELECT 1 y) +# SELECT y + 1 z +# FROM x +# GROUP BY x + 1 +# ORDER by z +def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + group = expression.args.get("group") + order = expression.args.get("order") + + if group and order: + aliases = { + select.this: select.args["alias"] + for select in expression.selects + if isinstance(select, exp.Alias) + } + + for grouped in group.expressions: + if grouped.is_int: + continue + alias = aliases.get(grouped) + if alias: + grouped.replace(exp.column(alias)) + + return expression + + +def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: + """BigQuery doesn't allow column names when defining a CTE, so we try to push them down.""" + if isinstance(expression, exp.CTE) and expression.alias_column_names: + cte_query = expression.this + + if cte_query.is_star: + logger.warning( + "Can't push down CTE column names for star queries. Run the query through" + " the optimizer or use 'qualify' to expand the star projections first." + ) + return expression + + column_names = expression.alias_column_names + expression.args["alias"].set("columns", None) + + for name, select in zip(column_names, cte_query.selects): + to_replace = select + + if isinstance(select, exp.Alias): + select = select.this + + # Inner aliases are shadowed by the CTE column names + to_replace.replace(exp.alias_(select, name)) + + return expression + + +def _build_parse_timestamp(args: t.List) -> exp.StrToTime: + this = build_formatted_time(exp.StrToTime, "bigquery")( + [seq_get(args, 1), seq_get(args, 0)] + ) + this.set("zone", seq_get(args, 2)) + return this + + +def _build_timestamp(args: t.List) -> exp.Timestamp: + timestamp = exp.Timestamp.from_arg_list(args) + timestamp.set("with_tz", True) + return timestamp + + +def _build_date(args: t.List) -> exp.Date | exp.DateFromParts: + expr_type = exp.DateFromParts if len(args) == 3 else exp.Date + return expr_type.from_arg_list(args) + + +def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5: + # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation + arg = seq_get(args, 0) + return ( + exp.MD5(this=arg.this) + if isinstance(arg, exp.MD5Digest) + else exp.LowerHex(this=arg) + ) + + +def _build_json_strip_nulls(args: t.List) -> exp.JSONStripNulls: + expression = exp.JSONStripNulls(this=seq_get(args, 0)) + + for arg in args[1:]: + if isinstance(arg, exp.Kwarg): + expression.set(arg.this.name.lower(), arg) + else: + expression.set("expression", arg) + + return expression + + +def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str: + return self.sql( + exp.Exists( + this=exp.select("1") + .from_( + exp.Unnest(expressions=[expression.left]).as_("_unnest", table=["_col"]) + ) + .where(exp.column("_col").eq(expression.right)) + ) + ) + + +def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str: + return date_add_interval_sql("DATE", "ADD")(self, ts_or_ds_add_cast(expression)) + + +def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str: + expression.this.replace(exp.cast(expression.this, 
exp.DataType.Type.TIMESTAMP)) + expression.expression.replace( + exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP) + ) + unit = unit_to_var(expression) + return self.func("DATE_DIFF", expression.this, expression.expression, unit) + + +def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = expression.this + + if scale in (None, exp.UnixToTime.SECONDS): + return self.func("TIMESTAMP_SECONDS", timestamp) + if scale == exp.UnixToTime.MILLIS: + return self.func("TIMESTAMP_MILLIS", timestamp) + if scale == exp.UnixToTime.MICROS: + return self.func("TIMESTAMP_MICROS", timestamp) + + unix_seconds = exp.cast( + exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), + exp.DataType.Type.BIGINT, + ) + return self.func("TIMESTAMP_SECONDS", unix_seconds) + + +def _build_time(args: t.List) -> exp.Func: + if len(args) == 1: + return exp.TsOrDsToTime(this=args[0]) + if len(args) == 2: + return exp.Time.from_arg_list(args) + return exp.TimeFromParts.from_arg_list(args) + + +def _build_datetime(args: t.List) -> exp.Func: + if len(args) == 1: + return exp.TsOrDsToDatetime.from_arg_list(args) + if len(args) == 2: + return exp.Datetime.from_arg_list(args) + return exp.TimestampFromParts.from_arg_list(args) + + +def build_date_diff(args: t.List) -> exp.Expression: + expr = exp.DateDiff( + this=seq_get(args, 0), + expression=seq_get(args, 1), + unit=seq_get(args, 2), + date_part_boundary=True, + ) + + # Normalize plain WEEK to WEEK(SUNDAY) to preserve the semantic in the AST to facilitate transpilation + # This is done post exp.DateDiff construction since the TimeUnit mixin performs canonicalizations in its constructor too + unit = expr.args.get("unit") + + if isinstance(unit, exp.Var) and unit.name.upper() == "WEEK": + expr.set("unit", exp.WeekStart(this=exp.var("SUNDAY"))) + + return expr + + +def _build_regexp_extract( + expr_type: t.Type[E], default_group: t.Optional[exp.Expression] = None +) -> t.Callable[[t.List, BigQuery], E]: + def _builder(args: t.List, dialect: BigQuery) -> E: + try: + group = re.compile(args[1].name).groups == 1 + except re.error: + group = False + + # Default group is used for the transpilation of REGEXP_EXTRACT_ALL + return expr_type( + this=seq_get(args, 0), + expression=seq_get(args, 1), + position=seq_get(args, 2), + occurrence=seq_get(args, 3), + group=exp.Literal.number(1) if group else default_group, + **( + { + "null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL + } + if expr_type is exp.RegexpExtract + else {} + ), + ) + + return _builder + + +def _build_extract_json_with_default_path( + expr_type: t.Type[E], +) -> t.Callable[[t.List, Dialect], E]: + def _builder(args: t.List, dialect: Dialect) -> E: + if len(args) == 1: + # The default value for the JSONPath is '$' i.e all of the data + args.append(exp.Literal.string("$")) + return parser.build_extract_json_with_path(expr_type)(args, dialect) + + return _builder + + +def _str_to_datetime_sql( + self: BigQuery.Generator, expression: exp.StrToDate | exp.StrToTime +) -> str: + this = self.sql(expression, "this") + dtype = "DATE" if isinstance(expression, exp.StrToDate) else "TIMESTAMP" + + if expression.args.get("safe"): + fmt = self.format_time( + expression, + self.dialect.INVERSE_FORMAT_MAPPING, + self.dialect.INVERSE_FORMAT_TRIE, + ) + return f"SAFE_CAST({this} AS {dtype} FORMAT {fmt})" + + fmt = self.format_time(expression) + return self.func(f"PARSE_{dtype}", fmt, this, 
expression.args.get("zone")) + + +@unsupported_args("ins_cost", "del_cost", "sub_cost") +def _levenshtein_sql(self: BigQuery.Generator, expression: exp.Levenshtein) -> str: + max_dist = expression.args.get("max_dist") + if max_dist: + max_dist = exp.Kwarg(this=exp.var("max_distance"), expression=max_dist) + + return self.func("EDIT_DISTANCE", expression.this, expression.expression, max_dist) + + +def _build_levenshtein(args: t.List) -> exp.Levenshtein: + max_dist = seq_get(args, 2) + return exp.Levenshtein( + this=seq_get(args, 0), + expression=seq_get(args, 1), + max_dist=max_dist.expression if max_dist else None, + ) + + +def _build_format_time( + expr_type: t.Type[exp.Expression], +) -> t.Callable[[t.List], exp.TimeToStr]: + def _builder(args: t.List) -> exp.TimeToStr: + formatted_time = build_formatted_time(exp.TimeToStr, "bigquery")( + [expr_type(this=seq_get(args, 1)), seq_get(args, 0)] + ) + formatted_time.set("zone", seq_get(args, 2)) + return formatted_time + + return _builder + + +def _build_contains_substring(args: t.List) -> exp.Contains: + # Lowercase the operands in case of transpilation, as exp.Contains + # is case-sensitive on other dialects + this = exp.Lower(this=seq_get(args, 0)) + expr = exp.Lower(this=seq_get(args, 1)) + + return exp.Contains(this=this, expression=expr, json_scope=seq_get(args, 2)) + + +def _json_extract_sql(self: BigQuery.Generator, expression: JSON_EXTRACT_TYPE) -> str: + name = (expression._meta and expression.meta.get("name")) or expression.sql_name() + upper = name.upper() + + dquote_escaping = upper in DQUOTES_ESCAPING_JSON_FUNCTIONS + + if dquote_escaping: + self._quote_json_path_key_using_brackets = False + + sql = rename_func(upper)(self, expression) + + if dquote_escaping: + self._quote_json_path_key_using_brackets = True + + return sql + + +class BigQuery(Dialect): + WEEK_OFFSET = -1 + UNNEST_COLUMN_ONLY = True + SUPPORTS_USER_DEFINED_TYPES = False + SUPPORTS_SEMI_ANTI_JOIN = False + LOG_BASE_FIRST = False + HEX_LOWERCASE = True + FORCE_EARLY_ALIAS_REF_EXPANSION = True + EXPAND_ONLY_GROUP_ALIAS_REF = True + PRESERVE_ORIGINAL_NAMES = True + HEX_STRING_IS_INTEGER_TYPE = True + BYTE_STRING_IS_BYTES_TYPE = True + UUID_IS_STRING_TYPE = True + ANNOTATE_ALL_SCOPES = True + PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = True + TABLES_REFERENCEABLE_AS_COLUMNS = True + SUPPORTS_STRUCT_STAR_EXPANSION = True + EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = True + QUERY_RESULTS_ARE_STRUCTS = True + JSON_EXTRACT_SCALAR_SCALAR_ONLY = True + LEAST_GREATEST_IGNORES_NULLS = False + DEFAULT_NULL_TYPE = exp.DataType.Type.BIGINT + PRIORITIZE_NON_LITERAL_TYPES = True + + # https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#initcap + INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v\\[\\](){}/|<>!?@"^#$&~_,.:;*%+\\-' + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE + + # bigquery udfs are case sensitive + NORMALIZE_FUNCTIONS = False + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time + TIME_MAPPING = { + "%x": "%m/%d/%y", + "%D": "%m/%d/%y", + "%E6S": "%S.%f", + "%e": "%-d", + "%F": "%Y-%m-%d", + "%T": "%H:%M:%S", + "%c": "%a %b %e %H:%M:%S %Y", + } + + INVERSE_TIME_MAPPING = { + # Preserve %E6S instead of expanding to %T.%f - since both %E6S & %T.%f are semantically different in BigQuery + # %E6S is semantically different from %T.%f: %E6S works as a single atomic specifier for seconds 
with microseconds, while %T.%f expands incorrectly and fails to parse. + "%H:%M:%S.%f": "%H:%M:%E6S", + } + + FORMAT_MAPPING = { + "DD": "%d", + "MM": "%m", + "MON": "%b", + "MONTH": "%B", + "YYYY": "%Y", + "YY": "%y", + "HH": "%I", + "HH12": "%I", + "HH24": "%H", + "MI": "%M", + "SS": "%S", + "SSSSS": "%f", + "TZH": "%z", + } + + # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement + # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table + # https://cloud.google.com/bigquery/docs/querying-wildcard-tables#scanning_a_range_of_tables_using_table_suffix + # https://cloud.google.com/bigquery/docs/query-cloud-storage-data#query_the_file_name_pseudo-column + PSEUDOCOLUMNS = { + "_PARTITIONTIME", + "_PARTITIONDATE", + "_TABLE_SUFFIX", + "_FILE_NAME", + "_DBT_MAX_PARTITION", + } + + # All set operations require either a DISTINCT or ALL specifier + SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys( + (exp.Except, exp.Intersect, exp.Union), None + ) + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/navigation_functions#percentile_cont + COERCES_TO = { + **TypeAnnotator.COERCES_TO, + exp.DataType.Type.BIGDECIMAL: {exp.DataType.Type.DOUBLE}, + } + COERCES_TO[exp.DataType.Type.DECIMAL] |= {exp.DataType.Type.BIGDECIMAL} + COERCES_TO[exp.DataType.Type.BIGINT] |= {exp.DataType.Type.BIGDECIMAL} + COERCES_TO[exp.DataType.Type.VARCHAR] |= { + exp.DataType.Type.DATE, + exp.DataType.Type.DATETIME, + exp.DataType.Type.TIME, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + } + + EXPRESSION_METADATA = EXPRESSION_METADATA.copy() + + def normalize_identifier(self, expression: E) -> E: + if ( + isinstance(expression, exp.Identifier) + and self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE + ): + parent = expression.parent + while isinstance(parent, exp.Dot): + parent = parent.parent + + # In BigQuery, CTEs are case-insensitive, but UDF and table names are case-sensitive + # by default. The following check uses a heuristic to detect tables based on whether + # they are qualified. This should generally be correct, because tables in BigQuery + # must be qualified with at least a dataset, unless @@dataset_id is set. 
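+            # For example (illustration only): a dataset-qualified reference such
+            # as `my_dataset.MyTable` keeps its casing under this heuristic, while
+            # an unqualified reference such as a CTE named `MyCte` falls through
+            # and is lowercased below.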
+ case_sensitive = ( + isinstance(parent, exp.UserDefinedFunction) + or ( + isinstance(parent, exp.Table) + and parent.db + and ( + parent.meta.get("quoted_table") + or not parent.meta.get("maybe_column") + ) + ) + or expression.meta.get("is_table") + ) + if not case_sensitive: + expression.set("this", expression.this.lower()) + + return t.cast(E, expression) + + return super().normalize_identifier(expression) + + class JSONPathTokenizer(jsonpath.JSONPathTokenizer): + VAR_TOKENS = { + TokenType.DASH, + TokenType.VAR, + } + + class Tokenizer(tokens.Tokenizer): + QUOTES = ["'", '"', '"""', "'''"] + COMMENTS = ["--", "#", ("/*", "*/")] + IDENTIFIERS = ["`"] + STRING_ESCAPES = ["\\"] + + HEX_STRINGS = [("0x", ""), ("0X", "")] + + BYTE_STRINGS = [ + (prefix + q, q) + for q in t.cast(t.List[str], QUOTES) + for prefix in ("b", "B") + ] + + RAW_STRINGS = [ + (prefix + q, q) + for q in t.cast(t.List[str], QUOTES) + for prefix in ("r", "R") + ] + + NESTED_COMMENTS = False + + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "ANY TYPE": TokenType.VARIANT, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, + "BYTEINT": TokenType.INT, + "BYTES": TokenType.BINARY, + "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, + "DATETIME": TokenType.TIMESTAMP, + "DECLARE": TokenType.DECLARE, + "ELSEIF": TokenType.COMMAND, + "EXCEPTION": TokenType.COMMAND, + "EXPORT": TokenType.EXPORT, + "FLOAT64": TokenType.DOUBLE, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, + "LOOP": TokenType.COMMAND, + "MODEL": TokenType.MODEL, + "NOT DETERMINISTIC": TokenType.VOLATILE, + "RECORD": TokenType.STRUCT, + "REPEAT": TokenType.COMMAND, + "TIMESTAMP": TokenType.TIMESTAMPTZ, + "WHILE": TokenType.COMMAND, + } + KEYWORDS.pop("DIV") + KEYWORDS.pop("VALUES") + KEYWORDS.pop("/*+") + + class Parser(parser.Parser): + PREFIXED_PIVOT_COLUMNS = True + LOG_DEFAULTS_TO_LN = True + SUPPORTS_IMPLICIT_UNNEST = True + JOINS_HAVE_EQUAL_PRECEDENCE = True + + # BigQuery does not allow ASC/DESC to be used as an identifier, allows GRANT as an identifier + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, + TokenType.GRANT, + } - {TokenType.ASC, TokenType.DESC} + + ALIAS_TOKENS = { + *parser.Parser.ALIAS_TOKENS, + TokenType.GRANT, + } - {TokenType.ASC, TokenType.DESC} + + TABLE_ALIAS_TOKENS = { + *parser.Parser.TABLE_ALIAS_TOKENS, + TokenType.GRANT, + } - {TokenType.ASC, TokenType.DESC} + + COMMENT_TABLE_ALIAS_TOKENS = { + *parser.Parser.COMMENT_TABLE_ALIAS_TOKENS, + TokenType.GRANT, + } - {TokenType.ASC, TokenType.DESC} + + UPDATE_ALIAS_TOKENS = { + *parser.Parser.UPDATE_ALIAS_TOKENS, + TokenType.GRANT, + } - {TokenType.ASC, TokenType.DESC} + + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "APPROX_TOP_COUNT": exp.ApproxTopK.from_arg_list, + "BIT_AND": exp.BitwiseAndAgg.from_arg_list, + "BIT_OR": exp.BitwiseOrAgg.from_arg_list, + "BIT_XOR": exp.BitwiseXorAgg.from_arg_list, + "BIT_COUNT": exp.BitwiseCount.from_arg_list, + "BOOL": exp.JSONBool.from_arg_list, + "CONTAINS_SUBSTR": _build_contains_substring, + "DATE": _build_date, + "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), + "DATE_DIFF": build_date_diff, + "DATE_SUB": build_date_delta_with_interval(exp.DateSub), + "DATE_TRUNC": lambda args: exp.DateTrunc( + unit=seq_get(args, 1), + this=seq_get(args, 0), + zone=seq_get(args, 2), + ), + "DATETIME": _build_datetime, + "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd), + "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub), + "DIV": binary_from_function(exp.IntDiv), + "EDIT_DISTANCE": _build_levenshtein, 
+ "FORMAT_DATE": _build_format_time(exp.TsOrDsToDate), + "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, + "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path( + exp.JSONExtractScalar + ), + "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path( + exp.JSONExtractArray + ), + "JSON_EXTRACT_STRING_ARRAY": _build_extract_json_with_default_path( + exp.JSONValueArray + ), + "JSON_KEYS": exp.JSONKeysAtDepth.from_arg_list, + "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract), + "JSON_QUERY_ARRAY": _build_extract_json_with_default_path( + exp.JSONExtractArray + ), + "JSON_STRIP_NULLS": _build_json_strip_nulls, + "JSON_VALUE": _build_extract_json_with_default_path(exp.JSONExtractScalar), + "JSON_VALUE_ARRAY": _build_extract_json_with_default_path( + exp.JSONValueArray + ), + "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), + "MD5": exp.MD5Digest.from_arg_list, + "SHA1": exp.SHA1Digest.from_arg_list, + "NORMALIZE_AND_CASEFOLD": lambda args: exp.Normalize( + this=seq_get(args, 0), form=seq_get(args, 1), is_casefold=True + ), + "OCTET_LENGTH": exp.ByteLength.from_arg_list, + "TO_HEX": _build_to_hex, + "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")( + [seq_get(args, 1), seq_get(args, 0)] + ), + "PARSE_TIME": lambda args: build_formatted_time(exp.ParseTime, "bigquery")( + [seq_get(args, 1), seq_get(args, 0)] + ), + "PARSE_TIMESTAMP": _build_parse_timestamp, + "PARSE_DATETIME": lambda args: build_formatted_time( + exp.ParseDatetime, "bigquery" + )([seq_get(args, 1), seq_get(args, 0)]), + "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, + "REGEXP_EXTRACT": _build_regexp_extract(exp.RegexpExtract), + "REGEXP_SUBSTR": _build_regexp_extract(exp.RegexpExtract), + "REGEXP_EXTRACT_ALL": _build_regexp_extract( + exp.RegexpExtractAll, default_group=exp.Literal.number(0) + ), + "SHA256": lambda args: exp.SHA2Digest( + this=seq_get(args, 0), length=exp.Literal.number(256) + ), + "SHA512": lambda args: exp.SHA2( + this=seq_get(args, 0), length=exp.Literal.number(512) + ), + "SPLIT": lambda args: exp.Split( + # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split + this=seq_get(args, 0), + expression=seq_get(args, 1) or exp.Literal.string(","), + ), + "STRPOS": exp.StrPosition.from_arg_list, + "TIME": _build_time, + "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd), + "TIME_SUB": build_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP": _build_timestamp, + "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd), + "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub), + "TIMESTAMP_MICROS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.MICROS + ), + "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS + ), + "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)), + "TO_JSON": lambda args: exp.JSONFormat( + this=seq_get(args, 0), options=seq_get(args, 1), to_json=True + ), + "TO_JSON_STRING": exp.JSONFormat.from_arg_list, + "FORMAT_DATETIME": _build_format_time(exp.TsOrDsToDatetime), + "FORMAT_TIMESTAMP": _build_format_time(exp.TsOrDsToTimestamp), + "FORMAT_TIME": _build_format_time(exp.TsOrDsToTime), + "FROM_HEX": exp.Unhex.from_arg_list, + "WEEK": lambda args: exp.WeekStart(this=exp.var(seq_get(args, 0))), + } + # Remove SEARCH to avoid parameter routing issues - let it fall back to Anonymous function + FUNCTIONS.pop("SEARCH") + + FUNCTION_PARSERS = { + 
**parser.Parser.FUNCTION_PARSERS, + "ARRAY": lambda self: self.expression( + exp.Array, + expressions=[self._parse_statement()], + struct_name_inheritance=True, + ), + "JSON_ARRAY": lambda self: self.expression( + exp.JSONArray, expressions=self._parse_csv(self._parse_bitwise) + ), + "MAKE_INTERVAL": lambda self: self._parse_make_interval(), + "PREDICT": lambda self: self._parse_ml(exp.Predict), + "TRANSLATE": lambda self: self._parse_translate(), + "FEATURES_AT_TIME": lambda self: self._parse_features_at_time(), + "GENERATE_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding), + "GENERATE_TEXT_EMBEDDING": lambda self: self._parse_ml( + exp.GenerateEmbedding, is_text=True + ), + "VECTOR_SEARCH": lambda self: self._parse_vector_search(), + "FORECAST": lambda self: self._parse_ml(exp.MLForecast), + } + FUNCTION_PARSERS.pop("TRIM") + + NO_PAREN_FUNCTIONS = { + **parser.Parser.NO_PAREN_FUNCTIONS, + TokenType.CURRENT_DATETIME: exp.CurrentDatetime, + } + + NESTED_TYPE_TOKENS = { + *parser.Parser.NESTED_TYPE_TOKENS, + TokenType.TABLE, + } + + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "NOT DETERMINISTIC": lambda self: self.expression( + exp.StabilityProperty, this=exp.Literal.string("VOLATILE") + ), + "OPTIONS": lambda self: self._parse_with_property(), + } + + CONSTRAINT_PARSERS = { + **parser.Parser.CONSTRAINT_PARSERS, + "OPTIONS": lambda self: exp.Properties( + expressions=self._parse_with_property() + ), + } + + RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy() + RANGE_PARSERS.pop(TokenType.OVERLAPS) + + DASHED_TABLE_PART_FOLLOW_TOKENS = { + TokenType.DOT, + TokenType.L_PAREN, + TokenType.R_PAREN, + } + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.ELSE: lambda self: self._parse_as_command(self._prev), + TokenType.END: lambda self: self._parse_as_command(self._prev), + TokenType.FOR: lambda self: self._parse_for_in(), + TokenType.EXPORT: lambda self: self._parse_export_data(), + TokenType.DECLARE: lambda self: self._parse_declare(), + } + + BRACKET_OFFSETS = { + "OFFSET": (0, False), + "ORDINAL": (1, False), + "SAFE_OFFSET": (0, True), + "SAFE_ORDINAL": (1, True), + } + + def _parse_for_in(self) -> t.Union[exp.ForIn, exp.Command]: + index = self._index + this = self._parse_range() + self._match_text_seq("DO") + if self._match(TokenType.COMMAND): + self._retreat(index) + return self._parse_as_command(self._prev) + return self.expression( + exp.ForIn, this=this, expression=self._parse_statement() + ) + + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: + this = super()._parse_table_part(schema=schema) or self._parse_number() + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names + if isinstance(this, exp.Identifier): + table_name = this.name + while self._match(TokenType.DASH, advance=False) and self._next: + start = self._curr + while self._is_connected() and not self._match_set( + self.DASHED_TABLE_PART_FOLLOW_TOKENS, advance=False + ): + self._advance() + + if start == self._curr: + break + + table_name += self._find_sql(start, self._prev) + + this = exp.Identifier( + this=table_name, quoted=this.args.get("quoted") + ).update_positions(this) + elif isinstance(this, exp.Literal): + table_name = this.name + + if self._is_connected() and self._parse_var(any_token=True): + table_name += self._prev.text + + this = exp.Identifier(this=table_name, quoted=True).update_positions( + this + ) + + return this + + def _parse_table_parts( + self, + schema: bool = False, + is_db_reference: 
bool = False, + wildcard: bool = False, + ) -> exp.Table: + table = super()._parse_table_parts( + schema=schema, is_db_reference=is_db_reference, wildcard=True + ) + + # proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here + if not table.catalog: + if table.db: + previous_db = table.args["db"] + parts = table.db.split(".") + if len(parts) == 2 and not table.args["db"].quoted: + table.set( + "catalog", + exp.Identifier(this=parts[0]).update_positions(previous_db), + ) + table.set( + "db", + exp.Identifier(this=parts[1]).update_positions(previous_db), + ) + else: + previous_this = table.this + parts = table.name.split(".") + if len(parts) == 2 and not table.this.quoted: + table.set( + "db", + exp.Identifier(this=parts[0]).update_positions( + previous_this + ), + ) + table.set( + "this", + exp.Identifier(this=parts[1]).update_positions( + previous_this + ), + ) + + if isinstance(table.this, exp.Identifier) and any( + "." in p.name for p in table.parts + ): + alias = table.this + catalog, db, this, *rest = ( + exp.to_identifier(p, quoted=True) + for p in split_num_words( + ".".join(p.name for p in table.parts), ".", 3 + ) + ) + + for part in (catalog, db, this): + if part: + part.update_positions(table.this) + + if rest and this: + this = exp.Dot.build([this, *rest]) # type: ignore + + table = exp.Table( + this=this, db=db, catalog=catalog, pivots=table.args.get("pivots") + ) + table.meta["quoted_table"] = True + else: + alias = None + + # The `INFORMATION_SCHEMA` views in BigQuery need to be qualified by a region or + # dataset, so if the project identifier is omitted we need to fix the ast so that + # the `INFORMATION_SCHEMA.X` bit is represented as a single (quoted) Identifier. + # Otherwise, we wouldn't correctly qualify a `Table` node that references these + # views, because it would seem like the "catalog" part is set, when it'd actually + # be the region/dataset. Merging the two identifiers into a single one is done to + # avoid producing a 4-part Table reference, which would cause issues in the schema + # module, when there are 3-part table names mixed with information schema views. + # + # See: https://cloud.google.com/bigquery/docs/information-schema-intro#syntax + table_parts = table.parts + if ( + len(table_parts) > 1 + and table_parts[-2].name.upper() == "INFORMATION_SCHEMA" + ): + # We need to alias the table here to avoid breaking existing qualified columns. + # This is expected to be safe, because if there's an actual alias coming up in + # the token stream, it will overwrite this one. If there isn't one, we are only + # exposing the name that can be used to reference the view explicitly (a no-op). + exp.alias_( + table, + t.cast(exp.Identifier, alias or table_parts[-1]), + table=True, + copy=False, + ) + + info_schema_view = f"{table_parts[-2].name}.{table_parts[-1].name}" + new_this = exp.Identifier( + this=info_schema_view, quoted=True + ).update_positions( + line=table_parts[-2].meta.get("line"), + col=table_parts[-1].meta.get("col"), + start=table_parts[-2].meta.get("start"), + end=table_parts[-1].meta.get("end"), + ) + table.set("this", new_this) + table.set("db", seq_get(table_parts, -3)) + table.set("catalog", seq_get(table_parts, -4)) + + return table + + def _parse_column(self) -> t.Optional[exp.Expression]: + column = super()._parse_column() + if isinstance(column, exp.Column): + parts = column.parts + if any("." 
in p.name for p in parts): + catalog, db, table, this, *rest = ( + exp.to_identifier(p, quoted=True) + for p in split_num_words( + ".".join(p.name for p in parts), ".", 4 + ) + ) + + if rest and this: + this = exp.Dot.build([this, *rest]) # type: ignore + + column = exp.Column(this=this, table=table, db=db, catalog=catalog) + column.meta["quoted_column"] = True + + return column + + @t.overload + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: + ... + + def _parse_json_object(self, agg=False): + json_object = super()._parse_json_object() + array_kv_pair = seq_get(json_object.expressions, 0) + + # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2 + if ( + array_kv_pair + and isinstance(array_kv_pair.this, exp.Array) + and isinstance(array_kv_pair.expression, exp.Array) + ): + keys = array_kv_pair.this.expressions + values = array_kv_pair.expression.expressions + + json_object.set( + "expressions", + [ + exp.JSONKeyValue(this=k, expression=v) + for k, v in zip(keys, values) + ], + ) + + return json_object + + def _parse_bracket( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + bracket = super()._parse_bracket(this) + + if isinstance(bracket, exp.Array): + bracket.set("struct_name_inheritance", True) + + if this is bracket: + return bracket + + if isinstance(bracket, exp.Bracket): + for expression in bracket.expressions: + name = expression.name.upper() + + if name not in self.BRACKET_OFFSETS: + break + + offset, safe = self.BRACKET_OFFSETS[name] + bracket.set("offset", offset) + bracket.set("safe", safe) + expression.replace(expression.expressions[0]) + + return bracket + + def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: + unnest = super()._parse_unnest(with_alias=with_alias) + + if not unnest: + return None + + unnest_expr = seq_get(unnest.expressions, 0) + if unnest_expr: + from bigframes_vendored.sqlglot.optimizer.annotate_types import ( + annotate_types, + ) + + unnest_expr = annotate_types(unnest_expr, dialect=self.dialect) + + # Unnesting a nested array (i.e array of structs) explodes the top-level struct fields, + # in contrast to other dialects such as DuckDB which flattens only the array by default + if unnest_expr.is_type(exp.DataType.Type.ARRAY) and any( + array_elem.is_type(exp.DataType.Type.STRUCT) + for array_elem in unnest_expr._type.expressions + ): + unnest.set("explode_array", True) + + return unnest + + def _parse_make_interval(self) -> exp.MakeInterval: + expr = exp.MakeInterval() + + for arg_key in MAKE_INTERVAL_KWARGS: + value = self._parse_lambda() + + if not value: + break + + # Non-named arguments are filled sequentially, (optionally) followed by named arguments + # that can appear in any order e.g MAKE_INTERVAL(1, minute => 5, day => 2) + if isinstance(value, exp.Kwarg): + arg_key = value.this.name + + expr.set(arg_key, value) + + self._match(TokenType.COMMA) + + return expr + + def _parse_ml(self, expr_type: t.Type[E], **kwargs) -> E: + self._match_text_seq("MODEL") + this = self._parse_table() + + self._match(TokenType.COMMA) + self._match_text_seq("TABLE") + + # Certain functions like ML.FORECAST require a STRUCT argument but not a TABLE/SELECT one + expression = ( + self._parse_table() + if not self._match(TokenType.STRUCT, advance=False) + else None + ) + + 
self._match(TokenType.COMMA) + + return self.expression( + expr_type, + this=this, + expression=expression, + params_struct=self._parse_bitwise(), + **kwargs, + ) + + def _parse_translate(self) -> exp.Translate | exp.MLTranslate: + # Check if this is ML.TRANSLATE by looking at previous tokens + token = seq_get(self._tokens, self._index - 4) + if token and token.text.upper() == "ML": + return self._parse_ml(exp.MLTranslate) + + return exp.Translate.from_arg_list(self._parse_function_args()) + + def _parse_features_at_time(self) -> exp.FeaturesAtTime: + self._match(TokenType.TABLE) + this = self._parse_table() + + expr = self.expression(exp.FeaturesAtTime, this=this) + + while self._match(TokenType.COMMA): + arg = self._parse_lambda() + + # Get the LHS of the Kwarg and set the arg to that value, e.g + # "num_rows => 1" sets the expr's `num_rows` arg + if arg: + expr.set(arg.this.name, arg) + + return expr + + def _parse_vector_search(self) -> exp.VectorSearch: + self._match(TokenType.TABLE) + base_table = self._parse_table() + + self._match(TokenType.COMMA) + + column_to_search = self._parse_bitwise() + self._match(TokenType.COMMA) + + self._match(TokenType.TABLE) + query_table = self._parse_table() + + expr = self.expression( + exp.VectorSearch, + this=base_table, + column_to_search=column_to_search, + query_table=query_table, + ) + + while self._match(TokenType.COMMA): + # query_column_to_search can be named argument or positional + if self._match(TokenType.STRING, advance=False): + query_column = self._parse_string() + expr.set("query_column_to_search", query_column) + else: + arg = self._parse_lambda() + if arg: + expr.set(arg.this.name, arg) + + return expr + + def _parse_export_data(self) -> exp.Export: + self._match_text_seq("DATA") + + return self.expression( + exp.Export, + connection=self._match_text_seq("WITH", "CONNECTION") + and self._parse_table_parts(), + options=self._parse_properties(), + this=self._match_text_seq("AS") and self._parse_select(), + ) + + def _parse_column_ops( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + this = super()._parse_column_ops(this) + + if isinstance(this, exp.Dot): + prefix_name = this.this.name.upper() + func_name = this.name.upper() + if prefix_name == "NET": + if func_name == "HOST": + this = self.expression( + exp.NetHost, this=seq_get(this.expression.expressions, 0) + ) + elif prefix_name == "SAFE": + if func_name == "TIMESTAMP": + this = _build_timestamp(this.expression.expressions) + this.set("safe", True) + + return this + + class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + QUERY_HINTS = False + TABLE_HINTS = False + LIMIT_FETCH = "LIMIT" + RENAME_TABLE_WITH_DB = False + NVL2_SUPPORTED = False + UNNEST_WITH_ORDINALITY = False + COLLATE_IS_FUNC = True + LIMIT_ONLY_LITERALS = True + SUPPORTS_TABLE_ALIAS_COLUMNS = False + UNPIVOT_ALIASES_ARE_IDENTIFIERS = False + JSON_KEY_VALUE_PAIR_SEP = "," + NULL_ORDERING_SUPPORTED = False + IGNORE_NULLS_IN_FUNC = True + JSON_PATH_SINGLE_QUOTE_ESCAPE = True + CAN_IMPLEMENT_ARRAY_ANY = True + SUPPORTS_TO_NUMBER = False + NAMED_PLACEHOLDER_TOKEN = "@" + HEX_FUNC = "TO_HEX" + WITH_PROPERTIES_PREFIX = "OPTIONS" + SUPPORTS_EXPLODING_PROJECTIONS = False + EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False + SUPPORTS_UNIX_SECONDS = True + + SAFE_JSON_PATH_KEY_RE = re.compile(r"^[_\-a-zA-Z][\-\w]*$") + + TS_OR_DS_TYPES = ( + exp.TsOrDsToDatetime, + exp.TsOrDsToTimestamp, + exp.TsOrDsToTime, + exp.TsOrDsToDate, + ) + + TRANSFORMS = { + 
**generator.Generator.TRANSFORMS, + exp.ApproxTopK: rename_func("APPROX_TOP_COUNT"), + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), + exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), + exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), + exp.Array: inline_array_unless_query, + exp.ArrayContains: _array_contains_sql, + exp.ArrayFilter: filter_array_using_unnest, + exp.ArrayRemove: filter_array_using_unnest, + exp.BitwiseAndAgg: rename_func("BIT_AND"), + exp.BitwiseOrAgg: rename_func("BIT_OR"), + exp.BitwiseXorAgg: rename_func("BIT_XOR"), + exp.BitwiseCount: rename_func("BIT_COUNT"), + exp.ByteLength: rename_func("BYTE_LENGTH"), + exp.Cast: transforms.preprocess( + [transforms.remove_precision_parameterized_types] + ), + exp.CollateProperty: lambda self, e: ( + f"DEFAULT COLLATE {self.sql(e, 'this')}" + if e.args.get("default") + else f"COLLATE {self.sql(e, 'this')}" + ), + exp.Commit: lambda *_: "COMMIT TRANSACTION", + exp.CountIf: rename_func("COUNTIF"), + exp.Create: _create_sql, + exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), + exp.DateAdd: date_add_interval_sql("DATE", "ADD"), + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", e.this, e.expression, unit_to_var(e) + ), + exp.DateFromParts: rename_func("DATE"), + exp.DateStrToDate: datestrtodate_sql, + exp.DateSub: date_add_interval_sql("DATE", "SUB"), + exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"), + exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), + exp.DateFromUnixDate: rename_func("DATE_FROM_UNIX_DATE"), + exp.FromTimeZone: lambda self, e: self.func( + "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'" + ), + exp.GenerateSeries: rename_func("GENERATE_ARRAY"), + exp.GroupConcat: lambda self, e: groupconcat_sql( + self, e, func_name="STRING_AGG", within_group=False, sep=None + ), + exp.Hex: lambda self, e: self.func( + "UPPER", self.func("TO_HEX", self.sql(e, "this")) + ), + exp.HexString: lambda self, e: self.hexstring_sql( + e, binary_function_repr="FROM_HEX" + ), + exp.If: if_sql(false_value="NULL"), + exp.ILike: no_ilike_sql, + exp.IntDiv: rename_func("DIV"), + exp.Int64: rename_func("INT64"), + exp.JSONBool: rename_func("BOOL"), + exp.JSONExtract: _json_extract_sql, + exp.JSONExtractArray: _json_extract_sql, + exp.JSONExtractScalar: _json_extract_sql, + exp.JSONFormat: lambda self, e: self.func( + "TO_JSON" if e.args.get("to_json") else "TO_JSON_STRING", + e.this, + e.args.get("options"), + ), + exp.JSONKeysAtDepth: rename_func("JSON_KEYS"), + exp.JSONValueArray: rename_func("JSON_VALUE_ARRAY"), + exp.Levenshtein: _levenshtein_sql, + exp.Max: max_or_greatest, + exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), + exp.MD5Digest: rename_func("MD5"), + exp.Min: min_or_least, + exp.Normalize: lambda self, e: self.func( + "NORMALIZE_AND_CASEFOLD" if e.args.get("is_casefold") else "NORMALIZE", + e.this, + e.args.get("form"), + ), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.RegexpExtract: lambda self, e: self.func( + "REGEXP_EXTRACT", + e.this, + e.expression, + e.args.get("position"), + e.args.get("occurrence"), + ), + exp.RegexpExtractAll: lambda self, e: self.func( + "REGEXP_EXTRACT_ALL", e.this, e.expression + ), + exp.RegexpReplace: regexp_replace_sql, + exp.RegexpLike: rename_func("REGEXP_CONTAINS"), + exp.ReturnsProperty: _returnsproperty_sql, + exp.Rollback: lambda *_: "ROLLBACK TRANSACTION", + exp.ParseTime: lambda self, e: self.func( + "PARSE_TIME", self.format_time(e), e.this + ), + 
exp.ParseDatetime: lambda self, e: self.func( + "PARSE_DATETIME", self.format_time(e), e.this + ), + exp.Select: transforms.preprocess( + [ + transforms.explode_projection_to_unnest(), + transforms.unqualify_unnest, + transforms.eliminate_distinct_on, + _alias_ordered_group, + transforms.eliminate_semi_and_anti_joins, + ] + ), + exp.SHA: rename_func("SHA1"), + exp.SHA2: sha256_sql, + exp.SHA1Digest: rename_func("SHA1"), + exp.SHA2Digest: sha2_digest_sql, + exp.StabilityProperty: lambda self, e: ( + "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" + ), + exp.String: rename_func("STRING"), + exp.StrPosition: lambda self, e: ( + strposition_sql( + self, + e, + func_name="INSTR", + supports_position=True, + supports_occurrence=True, + ) + ), + exp.StrToDate: _str_to_datetime_sql, + exp.StrToTime: _str_to_datetime_sql, + exp.SessionUser: lambda *_: "SESSION_USER()", + exp.TimeAdd: date_add_interval_sql("TIME", "ADD"), + exp.TimeFromParts: rename_func("TIME"), + exp.TimestampFromParts: rename_func("DATETIME"), + exp.TimeSub: date_add_interval_sql("TIME", "SUB"), + exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), + exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"), + exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), + exp.TimeStrToTime: timestrtotime_sql, + exp.Transaction: lambda *_: "BEGIN TRANSACTION", + exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsDiff: _ts_or_ds_diff_sql, + exp.TsOrDsToTime: rename_func("TIME"), + exp.TsOrDsToDatetime: rename_func("DATETIME"), + exp.TsOrDsToTimestamp: rename_func("TIMESTAMP"), + exp.Unhex: rename_func("FROM_HEX"), + exp.UnixDate: rename_func("UNIX_DATE"), + exp.UnixToTime: _unix_to_time_sql, + exp.Uuid: lambda *_: "GENERATE_UUID()", + exp.Values: _derived_table_values_to_unnest, + exp.VariancePop: rename_func("VAR_POP"), + exp.SafeDivide: rename_func("SAFE_DIVIDE"), + } + + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.BINARY: "BYTES", + exp.DataType.Type.BLOB: "BYTES", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.CHAR: "STRING", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.NCHAR: "STRING", + exp.DataType.Type.NVARCHAR: "STRING", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPNTZ: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.ROWVERSION: "BYTES", + exp.DataType.Type.UUID: "STRING", + exp.DataType.Type.VARBINARY: "BYTES", + exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.VARIANT: "ANY TYPE", + } + + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + + # WINDOW comes after QUALIFY + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#window_clause + AFTER_HAVING_MODIFIER_TRANSFORMS = { + "qualify": generator.Generator.AFTER_HAVING_MODIFIER_TRANSFORMS["qualify"], + "windows": generator.Generator.AFTER_HAVING_MODIFIER_TRANSFORMS["windows"], + } + + # from: 
https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords + RESERVED_KEYWORDS = { + "all", + "and", + "any", + "array", + "as", + "asc", + "assert_rows_modified", + "at", + "between", + "by", + "case", + "cast", + "collate", + "contains", + "create", + "cross", + "cube", + "current", + "default", + "define", + "desc", + "distinct", + "else", + "end", + "enum", + "escape", + "except", + "exclude", + "exists", + "extract", + "false", + "fetch", + "following", + "for", + "from", + "full", + "group", + "grouping", + "groups", + "hash", + "having", + "if", + "ignore", + "in", + "inner", + "intersect", + "interval", + "into", + "is", + "join", + "lateral", + "left", + "like", + "limit", + "lookup", + "merge", + "natural", + "new", + "no", + "not", + "null", + "nulls", + "of", + "on", + "or", + "order", + "outer", + "over", + "partition", + "preceding", + "proto", + "qualify", + "range", + "recursive", + "respect", + "right", + "rollup", + "rows", + "select", + "set", + "some", + "struct", + "tablesample", + "then", + "to", + "treat", + "true", + "unbounded", + "union", + "unnest", + "using", + "when", + "where", + "window", + "with", + "within", + } + + def datetrunc_sql(self, expression: exp.DateTrunc) -> str: + unit = expression.unit + unit_sql = unit.name if unit.is_string else self.sql(unit) + return self.func( + "DATE_TRUNC", expression.this, unit_sql, expression.args.get("zone") + ) + + def mod_sql(self, expression: exp.Mod) -> str: + this = expression.this + expr = expression.expression + return self.func( + "MOD", + this.unnest() if isinstance(this, exp.Paren) else this, + expr.unnest() if isinstance(expr, exp.Paren) else expr, + ) + + def column_parts(self, expression: exp.Column) -> str: + if expression.meta.get("quoted_column"): + # If a column reference is of the form `dataset.table`.name, we need + # to preserve the quoted table path, otherwise the reference breaks + table_parts = ".".join(p.name for p in expression.parts[:-1]) + table_path = self.sql(exp.Identifier(this=table_parts, quoted=True)) + return f"{table_path}.{self.sql(expression, 'this')}" + + return super().column_parts(expression) + + def table_parts(self, expression: exp.Table) -> str: + # Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so + # we need to make sure the correct quoting is used in each case. 
+ # + # For example, if there is a CTE x that clashes with a schema name, then the former will + # return the table y in that schema, whereas the latter will return the CTE's y column: + # + # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join + # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest + if expression.meta.get("quoted_table"): + table_parts = ".".join(p.name for p in expression.parts) + return self.sql(exp.Identifier(this=table_parts, quoted=True)) + + return super().table_parts(expression) + + def timetostr_sql(self, expression: exp.TimeToStr) -> str: + this = expression.this + if isinstance(this, exp.TsOrDsToDatetime): + func_name = "FORMAT_DATETIME" + elif isinstance(this, exp.TsOrDsToTimestamp): + func_name = "FORMAT_TIMESTAMP" + elif isinstance(this, exp.TsOrDsToTime): + func_name = "FORMAT_TIME" + else: + func_name = "FORMAT_DATE" + + time_expr = this if isinstance(this, self.TS_OR_DS_TYPES) else expression + return self.func( + func_name, + self.format_time(expression), + time_expr.this, + expression.args.get("zone"), + ) + + def eq_sql(self, expression: exp.EQ) -> str: + # Operands of = cannot be NULL in BigQuery + if isinstance(expression.left, exp.Null) or isinstance( + expression.right, exp.Null + ): + if not isinstance(expression.parent, exp.Update): + return "NULL" + + return self.binary(expression, "=") + + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: + parent = expression.parent + + # BigQuery allows CAST(.. AS {STRING|TIMESTAMP} [FORMAT [AT TIME ZONE ]]). + # Only the TIMESTAMP one should use the below conversion, when AT TIME ZONE is included. + if not isinstance(parent, exp.Cast) or not parent.to.is_type("text"): + return self.func( + "TIMESTAMP", + self.func("DATETIME", expression.this, expression.args.get("zone")), + ) + + return super().attimezone_sql(expression) + + def trycast_sql(self, expression: exp.TryCast) -> str: + return self.cast_sql(expression, safe_prefix="SAFE_") + + def bracket_sql(self, expression: exp.Bracket) -> str: + this = expression.this + expressions = expression.expressions + + if ( + len(expressions) == 1 + and this + and this.is_type(exp.DataType.Type.STRUCT) + ): + arg = expressions[0] + if arg.type is None: + from bigframes_vendored.sqlglot.optimizer.annotate_types import ( + annotate_types, + ) + + arg = annotate_types(arg, dialect=self.dialect) + + if arg.type and arg.type.this in exp.DataType.TEXT_TYPES: + # BQ doesn't support bracket syntax with string values for structs + return f"{self.sql(this)}.{arg.name}" + + expressions_sql = self.expressions(expression, flat=True) + offset = expression.args.get("offset") + + if offset == 0: + expressions_sql = f"OFFSET({expressions_sql})" + elif offset == 1: + expressions_sql = f"ORDINAL({expressions_sql})" + elif offset is not None: + self.unsupported(f"Unsupported array offset: {offset}") + + if expression.args.get("safe"): + expressions_sql = f"SAFE_{expressions_sql}" + + return f"{self.sql(this)}[{expressions_sql}]" + + def in_unnest_op(self, expression: exp.Unnest) -> str: + return self.sql(expression) + + def version_sql(self, expression: exp.Version) -> str: + if expression.name == "TIMESTAMP": + expression.set("this", "SYSTEM_TIME") + return super().version_sql(expression) + + def contains_sql(self, expression: exp.Contains) -> str: + this = expression.this + expr = expression.expression + + if isinstance(this, exp.Lower) and isinstance(expr, exp.Lower): + this = this.this + expr = expr.this + + return self.func( + 
"CONTAINS_SUBSTR", this, expr, expression.args.get("json_scope") + ) + + def cast_sql( + self, expression: exp.Cast, safe_prefix: t.Optional[str] = None + ) -> str: + this = expression.this + + # This ensures that inline type-annotated ARRAY literals like ARRAY[1, 2, 3] + # are roundtripped unaffected. The inner check excludes ARRAY(SELECT ...) expressions, + # because they aren't literals and so the above syntax is invalid BigQuery. + if isinstance(this, exp.Array): + elem = seq_get(this.expressions, 0) + if not (elem and elem.find(exp.Query)): + return f"{self.sql(expression, 'to')}{self.sql(this)}" + + return super().cast_sql(expression, safe_prefix=safe_prefix) + + def declareitem_sql(self, expression: exp.DeclareItem) -> str: + variables = self.expressions(expression, "this") + default = self.sql(expression, "default") + default = f" DEFAULT {default}" if default else "" + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + + return f"{variables}{kind}{default}" + + def timestamp_sql(self, expression: exp.Timestamp) -> str: + prefix = "SAFE." if expression.args.get("safe") else "" + return self.func( + f"{prefix}TIMESTAMP", expression.this, expression.args.get("zone") + ) diff --git a/third_party/bigframes_vendored/sqlglot/dialects/dialect.py b/third_party/bigframes_vendored/sqlglot/dialects/dialect.py new file mode 100644 index 0000000000..449a6e2494 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/dialects/dialect.py @@ -0,0 +1,2359 @@ +from __future__ import annotations + +from enum import auto, Enum +from functools import reduce +import importlib +import logging +import sys +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects import DIALECT_MODULE_NAMES +from bigframes_vendored.sqlglot.errors import ParseError +from bigframes_vendored.sqlglot.generator import Generator, unsupported_args +from bigframes_vendored.sqlglot.helper import ( + AutoName, + flatten, + is_int, + seq_get, + suggest_closest_match_and_fail, + to_bool, +) +from bigframes_vendored.sqlglot.jsonpath import JSONPathTokenizer +from bigframes_vendored.sqlglot.jsonpath import parse as parse_json_path +from bigframes_vendored.sqlglot.parser import Parser +from bigframes_vendored.sqlglot.time import format_time, subsecond_precision, TIMEZONES +from bigframes_vendored.sqlglot.tokens import Token, Tokenizer, TokenType +from bigframes_vendored.sqlglot.trie import new_trie +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA + +DATE_ADD_OR_DIFF = t.Union[ + exp.DateAdd, + exp.DateDiff, + exp.DateSub, + exp.TsOrDsAdd, + exp.TsOrDsDiff, +] +DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] +JSON_EXTRACT_TYPE = t.Union[ + exp.JSONExtract, exp.JSONExtractScalar, exp.JSONBExtract, exp.JSONBExtractScalar +] +DATETIME_DELTA = t.Union[ + exp.DateAdd, + exp.DatetimeAdd, + exp.DatetimeSub, + exp.TimeAdd, + exp.TimeSub, + exp.TimestampAdd, + exp.TimestampSub, + exp.TsOrDsAdd, +] +DATETIME_ADD = ( + exp.DateAdd, + exp.TimeAdd, + exp.DatetimeAdd, + exp.TsOrDsAdd, + exp.TimestampAdd, +) + +if t.TYPE_CHECKING: + from sqlglot._typing import B, E, F + +logger = logging.getLogger("sqlglot") + +UNESCAPED_SEQUENCES = { + "\\a": "\a", + "\\b": "\b", + "\\f": "\f", + "\\n": "\n", + "\\r": "\r", + "\\t": "\t", + "\\v": "\v", + "\\\\": "\\", +} + + +class Dialects(str, Enum): + """Dialects supported by SQLGLot.""" + + DIALECT = "" + + ATHENA = "athena" + BIGQUERY = "bigquery" + CLICKHOUSE = "clickhouse" + DATABRICKS = "databricks" 
+ DORIS = "doris" + DREMIO = "dremio" + DRILL = "drill" + DRUID = "druid" + DUCKDB = "duckdb" + DUNE = "dune" + FABRIC = "fabric" + HIVE = "hive" + MATERIALIZE = "materialize" + MYSQL = "mysql" + ORACLE = "oracle" + POSTGRES = "postgres" + PRESTO = "presto" + PRQL = "prql" + REDSHIFT = "redshift" + RISINGWAVE = "risingwave" + SNOWFLAKE = "snowflake" + SOLR = "solr" + SPARK = "spark" + SPARK2 = "spark2" + SQLITE = "sqlite" + STARROCKS = "starrocks" + TABLEAU = "tableau" + TERADATA = "teradata" + TRINO = "trino" + TSQL = "tsql" + EXASOL = "exasol" + + +class NormalizationStrategy(str, AutoName): + """Specifies the strategy according to which identifiers should be normalized.""" + + LOWERCASE = auto() + """Unquoted identifiers are lowercased.""" + + UPPERCASE = auto() + """Unquoted identifiers are uppercased.""" + + CASE_SENSITIVE = auto() + """Always case-sensitive, regardless of quotes.""" + + CASE_INSENSITIVE = auto() + """Always case-insensitive (lowercase), regardless of quotes.""" + + CASE_INSENSITIVE_UPPERCASE = auto() + """Always case-insensitive (uppercase), regardless of quotes.""" + + +class _Dialect(type): + _classes: t.Dict[str, t.Type[Dialect]] = {} + + def __eq__(cls, other: t.Any) -> bool: + if cls is other: + return True + if isinstance(other, str): + return cls is cls.get(other) + if isinstance(other, Dialect): + return cls is type(other) + + return False + + def __hash__(cls) -> int: + return hash(cls.__name__.lower()) + + @property + def classes(cls): + if len(DIALECT_MODULE_NAMES) != len(cls._classes): + for key in DIALECT_MODULE_NAMES: + cls._try_load(key) + + return cls._classes + + @classmethod + def _try_load(cls, key: str | Dialects) -> None: + if isinstance(key, Dialects): + key = key.value + + # This import will lead to a new dialect being loaded, and hence, registered. + # We check that the key is an actual sqlglot module to avoid blindly importing + # files. Custom user dialects need to be imported at the top-level package, in + # order for them to be registered as soon as possible. 
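+        # For example (illustration only): the first call to _Dialect.get("duckdb")
+        # ends up here, imports the vendored "duckdb" dialect module, and that
+        # import registers the dialect class in cls._classes via the metaclass's
+        # __new__ below.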
+        if key in DIALECT_MODULE_NAMES:
+            importlib.import_module(f"bigframes_vendored.sqlglot.dialects.{key}")
+
+    @classmethod
+    def __getitem__(cls, key: str) -> t.Type[Dialect]:
+        if key not in cls._classes:
+            cls._try_load(key)
+
+        return cls._classes[key]
+
+    @classmethod
+    def get(
+        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+    ) -> t.Optional[t.Type[Dialect]]:
+        if key not in cls._classes:
+            cls._try_load(key)
+
+        return cls._classes.get(key, default)
+
+    def __new__(cls, clsname, bases, attrs):
+        klass = super().__new__(cls, clsname, bases, attrs)
+        enum = Dialects.__members__.get(clsname.upper())
+        cls._classes[enum.value if enum is not None else clsname.lower()] = klass
+
+        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
+        klass.FORMAT_TRIE = (
+            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
+        )
+        # Merge class-defined INVERSE_TIME_MAPPING with auto-generated mappings
+        # This allows dialects to define custom inverse mappings for roundtrip correctness
+        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} | (
+            klass.__dict__.get("INVERSE_TIME_MAPPING") or {}
+        )
+        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
+        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
+        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
+
+        klass.INVERSE_CREATABLE_KIND_MAPPING = {
+            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
+        }
+
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_jsonpath_tokenizer = (
+            getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),
+        )
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
+
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.jsonpath_tokenizer_class = klass.__dict__.get(
+            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get(
+            "Parser", type("Parser", base_parser, {})
+        )
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
+
+        klass.QUOTE_START, klass.QUOTE_END = list(
+            klass.tokenizer_class._QUOTES.items()
+        )[0]
+        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
+            klass.tokenizer_class._IDENTIFIERS.items()
+        )[0]
+
+        def get_start_end(
+            token_type: TokenType,
+        ) -> t.Tuple[t.Optional[str], t.Optional[str]]:
+            return next(
+                (
+                    (s, e)
+                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
+                    if t == token_type
+                ),
+                (None, None),
+            )
+
+        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
+        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
+        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
+        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
+
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                **UNESCAPED_SEQUENCES,
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
+        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
+
+        if enum not in ("", "bigquery", "snowflake"):
+            klass.INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
+
+        if enum not in ("", "bigquery"):
+            klass.generator_class.SELECT_KINDS = ()
+
+        if enum not in ("", "athena", "presto", "trino", "duckdb"):
+            klass.generator_class.TRY_SUPPORTED = False
+
klass.generator_class.SUPPORTS_UESCAPE = False + + if enum not in ("", "databricks", "hive", "spark", "spark2"): + modifier_transforms = ( + klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() + ) + for modifier in ("cluster", "distribute", "sort"): + modifier_transforms.pop(modifier, None) + + klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms + + if enum not in ("", "doris", "mysql"): + klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { + TokenType.STRAIGHT_JOIN, + } + klass.parser_class.TABLE_ALIAS_TOKENS = ( + klass.parser_class.TABLE_ALIAS_TOKENS + | { + TokenType.STRAIGHT_JOIN, + } + ) + + if enum not in ("", "databricks", "oracle", "redshift", "snowflake", "spark"): + klass.generator_class.SUPPORTS_DECODE_CASE = False + + if not klass.SUPPORTS_SEMI_ANTI_JOIN: + klass.parser_class.TABLE_ALIAS_TOKENS = ( + klass.parser_class.TABLE_ALIAS_TOKENS + | { + TokenType.ANTI, + TokenType.SEMI, + } + ) + + if enum not in ( + "", + "postgres", + "duckdb", + "redshift", + "snowflake", + "presto", + "trino", + "mysql", + "singlestore", + ): + no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy() + no_paren_functions.pop(TokenType.LOCALTIME, None) + if enum != "oracle": + no_paren_functions.pop(TokenType.LOCALTIMESTAMP, None) + klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions + + if enum in ( + "", + "postgres", + "duckdb", + "trino", + ): + no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy() + no_paren_functions[TokenType.CURRENT_CATALOG] = exp.CurrentCatalog + klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions + else: + # For dialects that don't support this keyword, treat it as a regular identifier + # This fixes the "Unexpected token" error in BQ, Spark, etc. + klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { + TokenType.CURRENT_CATALOG, + } + + if enum in ( + "", + "duckdb", + "spark", + "postgres", + "tsql", + ): + no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy() + no_paren_functions[TokenType.SESSION_USER] = exp.SessionUser + klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions + else: + klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { + TokenType.SESSION_USER, + } + + klass.VALID_INTERVAL_UNITS = { + *klass.VALID_INTERVAL_UNITS, + *klass.DATE_PART_MAPPING.keys(), + *klass.DATE_PART_MAPPING.values(), + } + + return klass + + +class Dialect(metaclass=_Dialect): + INDEX_OFFSET = 0 + """The base index offset for arrays.""" + + WEEK_OFFSET = 0 + """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). 
-1 would be Sunday.""" + + UNNEST_COLUMN_ONLY = False + """Whether `UNNEST` table aliases are treated as column aliases.""" + + ALIAS_POST_TABLESAMPLE = False + """Whether the table alias comes after tablesample.""" + + TABLESAMPLE_SIZE_IS_PERCENT = False + """Whether a size in the table sample clause represents percentage.""" + + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE + """Specifies the strategy according to which identifiers should be normalized.""" + + IDENTIFIERS_CAN_START_WITH_DIGIT = False + """Whether an unquoted identifier can start with a digit.""" + + DPIPE_IS_STRING_CONCAT = True + """Whether the DPIPE token (`||`) is a string concatenation operator.""" + + STRICT_STRING_CONCAT = False + """Whether `CONCAT`'s arguments must be strings.""" + + SUPPORTS_USER_DEFINED_TYPES = True + """Whether user-defined data types are supported.""" + + SUPPORTS_SEMI_ANTI_JOIN = True + """Whether `SEMI` or `ANTI` joins are supported.""" + + SUPPORTS_COLUMN_JOIN_MARKS = False + """Whether the old-style outer join (+) syntax is supported.""" + + COPY_PARAMS_ARE_CSV = True + """Separator of COPY statement parameters.""" + + NORMALIZE_FUNCTIONS: bool | str = "upper" + """ + Determines how function names are going to be normalized. + Possible values: + "upper" or True: Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + """ + + PRESERVE_ORIGINAL_NAMES: bool = False + """ + Whether the name of the function should be preserved inside the node's metadata, + can be useful for roundtripping deprecated vs new functions that share an AST node + e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery + """ + + LOG_BASE_FIRST: t.Optional[bool] = True + """ + Whether the base comes first in the `LOG` function. + Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) + """ + + NULL_ORDERING = "nulls_are_small" + """ + Default `NULL` ordering method to use if not explicitly set. + Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` + """ + + TYPED_DIVISION = False + """ + Whether the behavior of `a / b` depends on the types of `a` and `b`. + False means `a / b` is always float division. + True means `a / b` is integer division if both `a` and `b` are integers. + """ + + SAFE_DIVISION = False + """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" + + CONCAT_COALESCE = False + """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" + + HEX_LOWERCASE = False + """Whether the `HEX` function returns a lowercase hexadecimal string.""" + + DATE_FORMAT = "'%Y-%m-%d'" + DATEINT_FORMAT = "'%Y%m%d'" + TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" + + TIME_MAPPING: t.Dict[str, str] = {} + """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time + # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE + FORMAT_MAPPING: t.Dict[str, str] = {} + """ + Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. + If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
+ """ + + UNESCAPED_SEQUENCES: t.Dict[str, str] = {} + """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" + + PSEUDOCOLUMNS: t.Set[str] = set() + """ + Columns that are auto-generated by the engine corresponding to this dialect. + For example, such columns may be excluded from `SELECT *` queries. + """ + + PREFER_CTE_ALIAS_COLUMN = False + """ + Some dialects, such as Snowflake, allow you to reference a CTE column alias in the + HAVING clause of the CTE. This flag will cause the CTE alias columns to override + any projection aliases in the subquery. + + For example, + WITH y(c) AS ( + SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 + ) SELECT c FROM y; + + will be rewritten as + + WITH y(c) AS ( + SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 + ) SELECT c FROM y; + """ + + COPY_PARAMS_ARE_CSV = True + """ + Whether COPY statement parameters are separated by comma or whitespace + """ + + FORCE_EARLY_ALIAS_REF_EXPANSION = False + """ + Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). + + For example: + WITH data AS ( + SELECT + 1 AS id, + 2 AS my_id + ) + SELECT + id AS my_id + FROM + data + WHERE + my_id = 1 + GROUP BY + my_id, + HAVING + my_id = 1 + + In most dialects, "my_id" would refer to "data.my_id" across the query, except: + - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e + it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" + - Clickhouse, which will forward the alias across the query i.e it resolves + to "WHERE id = 1 GROUP BY id HAVING id = 1" + """ + + EXPAND_ONLY_GROUP_ALIAS_REF = False + """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" + + ANNOTATE_ALL_SCOPES = False + """Whether to annotate all scopes during optimization. Used by BigQuery for UNNEST support.""" + + DISABLES_ALIAS_REF_EXPANSION = False + """ + Whether alias reference expansion is disabled for this dialect. + + Some dialects like Oracle do NOT support referencing aliases in projections or WHERE clauses. + The original expression must be repeated instead. + + For example, in Oracle: + SELECT y.foo AS bar, bar * 2 AS baz FROM y -- INVALID + SELECT y.foo AS bar, y.foo * 2 AS baz FROM y -- VALID + """ + + SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = False + """ + Whether alias references are allowed in JOIN ... ON clauses. + + Most dialects do not support this, but Snowflake allows alias expansion in the JOIN ... ON + clause (and almost everywhere else) + + For example, in Snowflake: + SELECT a.id AS user_id FROM a JOIN b ON user_id = b.id -- VALID + + Reference: https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes + """ + + SUPPORTS_ORDER_BY_ALL = False + """ + Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks + """ + + PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = False + """ + Whether projection alias names can shadow table/source names in GROUP BY and HAVING clauses. + + In BigQuery, when a projection alias has the same name as a source table, the alias takes + precedence in GROUP BY and HAVING clauses, and the table becomes inaccessible by that name. + + For example, in BigQuery: + SELECT id, ARRAY_AGG(col) AS custom_fields + FROM custom_fields + GROUP BY id + HAVING id >= 1 + + The "custom_fields" source is shadowed by the projection alias, so we cannot qualify "id" + with "custom_fields" in GROUP BY/HAVING. 
+ """ + + TABLES_REFERENCEABLE_AS_COLUMNS = False + """ + Whether table names can be referenced as columns (treated as structs). + + BigQuery allows tables to be referenced as columns in queries, automatically treating + them as struct values containing all the table's columns. + + For example, in BigQuery: + SELECT t FROM my_table AS t -- Returns entire row as a struct + """ + + SUPPORTS_STRUCT_STAR_EXPANSION = False + """ + Whether the dialect supports expanding struct fields using star notation (e.g., struct_col.*). + + BigQuery allows struct fields to be expanded with the star operator: + SELECT t.struct_col.* FROM table t + RisingWave also allows struct field expansion with the star operator using parentheses: + SELECT (t.struct_col).* FROM table t + + This expands to all fields within the struct. + """ + + EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = False + """ + Whether pseudocolumns should be excluded from star expansion (SELECT *). + + Pseudocolumns are special dialect-specific columns (e.g., Oracle's ROWNUM, ROWID, LEVEL, + or BigQuery's _PARTITIONTIME, _PARTITIONDATE) that are implicitly available but not part + of the table schema. When this is True, SELECT * will not include these pseudocolumns; + they must be explicitly selected. + """ + + QUERY_RESULTS_ARE_STRUCTS = False + """ + Whether query results are typed as structs in metadata for type inference. + + In BigQuery, subqueries store their column types as a STRUCT in metadata, + enabling special type inference for ARRAY(SELECT ...) expressions: + ARRAY(SELECT x, y FROM t) → ARRAY> + + For single column subqueries, BigQuery unwraps the struct: + ARRAY(SELECT x FROM t) → ARRAY + + This is metadata-only for type inference. + """ + + REQUIRES_PARENTHESIZED_STRUCT_ACCESS = False + """ + Whether struct field access requires parentheses around the expression. + + RisingWave requires parentheses for struct field access in certain contexts: + SELECT (col.field).subfield FROM table -- Parentheses required + + Without parentheses, the parser may not correctly interpret nested struct access. + + Reference: https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct + """ + + SUPPORTS_NULL_TYPE = False + """ + Whether NULL/VOID is supported as a valid data type (not just a value). + + Databricks and Spark v3+ support NULL as an actual type, allowing expressions like: + SELECT NULL AS col -- Has type NULL, not just value NULL + CAST(x AS VOID) -- Valid type cast + """ + + COALESCE_COMPARISON_NON_STANDARD = False + """ + Whether COALESCE in comparisons has non-standard NULL semantics. + + We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, + because they are not always equivalent. For example, if `x` is `NULL` and it comes + from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`. + + In standard SQL and most dialects, these expressions are equivalent, but Redshift treats + table NULLs differently in this context. + """ + + HAS_DISTINCT_ARRAY_CONSTRUCTORS = False + """ + Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) + as the former is of type INT[] vs the latter which is SUPER + """ + + SUPPORTS_FIXED_SIZE_ARRAYS = False + """ + Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. + in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should + be interpreted as a subscript/index operator. 
+ """ + + STRICT_JSON_PATH_SYNTAX = True + """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" + + ON_CONDITION_EMPTY_BEFORE_ERROR = True + """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" + + ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True + """Whether ArrayAgg needs to filter NULL values.""" + + PROMOTE_TO_INFERRED_DATETIME_TYPE = False + """ + This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted + to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal + is cast to x's type to match it instead. + """ + + SUPPORTS_VALUES_DEFAULT = True + """Whether the DEFAULT keyword is supported in the VALUES clause.""" + + NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False + """Whether number literals can include underscores for better readability""" + + HEX_STRING_IS_INTEGER_TYPE: bool = False + """Whether hex strings such as x'CC' evaluate to integer or binary/blob type""" + + REGEXP_EXTRACT_DEFAULT_GROUP = 0 + """The default value for the capturing group.""" + + REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL = True + """Whether REGEXP_EXTRACT returns NULL when the position arg exceeds the string length.""" + + SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { + exp.Except: True, + exp.Intersect: True, + exp.Union: True, + } + """ + Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` + must be explicitly specified. + """ + + CREATABLE_KIND_MAPPING: dict[str, str] = {} + """ + Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse + equivalent of CREATE SCHEMA is CREATE DATABASE. + """ + + ALTER_TABLE_SUPPORTS_CASCADE = False + """ + Hive by default does not update the schema of existing partitions when a column is changed. + the CASCADE clause is used to indicate that the change should be propagated to all existing partitions. + the Spark dialect, while derived from Hive, does not support the CASCADE clause. + """ + + # Whether ADD is present for each column added by ALTER TABLE + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True + + # Whether the value/LHS of the TRY_CAST( AS ) should strictly be a + # STRING type (Snowflake's case) or can be of any type + TRY_CAST_REQUIRES_STRING: t.Optional[bool] = None + + # Whether the double negation can be applied + # Not safe with MySQL and SQLite due to type coercion (may not return boolean) + SAFE_TO_ELIMINATE_DOUBLE_NEGATION = True + + # Whether the INITCAP function supports custom delimiter characters as the second argument + # Default delimiter characters for INITCAP function: whitespace and non-alphanumeric characters + INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True + INITCAP_DEFAULT_DELIMITER_CHARS = ( + " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~" + ) + + BYTE_STRING_IS_BYTES_TYPE: bool = False + """ + Whether byte string literals (ex: BigQuery's b'...') are typed as BYTES/BINARY + """ + + UUID_IS_STRING_TYPE: bool = False + """ + Whether a UUID is considered a string or a UUID type. + """ + + JSON_EXTRACT_SCALAR_SCALAR_ONLY = False + """ + Whether JSON_EXTRACT_SCALAR returns null if a non-scalar value is selected. + """ + + DEFAULT_FUNCTIONS_COLUMN_NAMES: t.Dict[ + t.Type[exp.Func], t.Union[str, t.Tuple[str, ...]] + ] = {} + """ + Maps function expressions to their default output column name(s). 
+ + For example, in Postgres, generate_series function outputs a column named "generate_series" by default, + so we map the ExplodingGenerateSeries expression to "generate_series" string. + """ + + DEFAULT_NULL_TYPE = exp.DataType.Type.UNKNOWN + """ + The default type of NULL for producing the correct projection type. + + For example, in BigQuery the default type of the NULL value is INT64. + """ + + LEAST_GREATEST_IGNORES_NULLS = True + """ + Whether LEAST/GREATEST functions ignore NULL values, e.g: + - BigQuery, Snowflake, MySQL, Presto/Trino: LEAST(1, NULL, 2) -> NULL + - Spark, Postgres, DuckDB, TSQL: LEAST(1, NULL, 2) -> 1 + """ + + PRIORITIZE_NON_LITERAL_TYPES = False + """ + Whether to prioritize non-literal types over literals during type annotation. + """ + + # --- Autofilled --- + + tokenizer_class = Tokenizer + jsonpath_tokenizer_class = JSONPathTokenizer + parser_class = Parser + generator_class = Generator + + # A trie of the time_mapping keys + TIME_TRIE: t.Dict = {} + FORMAT_TRIE: t.Dict = {} + + INVERSE_TIME_MAPPING: t.Dict[str, str] = {} + INVERSE_TIME_TRIE: t.Dict = {} + INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} + INVERSE_FORMAT_TRIE: t.Dict = {} + + INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} + + ESCAPED_SEQUENCES: t.Dict[str, str] = {} + + # Delimiters for string literals and identifiers + QUOTE_START = "'" + QUOTE_END = "'" + IDENTIFIER_START = '"' + IDENTIFIER_END = '"' + + VALID_INTERVAL_UNITS: t.Set[str] = set() + + # Delimiters for bit, hex, byte and unicode literals + BIT_START: t.Optional[str] = None + BIT_END: t.Optional[str] = None + HEX_START: t.Optional[str] = None + HEX_END: t.Optional[str] = None + BYTE_START: t.Optional[str] = None + BYTE_END: t.Optional[str] = None + UNICODE_START: t.Optional[str] = None + UNICODE_END: t.Optional[str] = None + + DATE_PART_MAPPING = { + "Y": "YEAR", + "YY": "YEAR", + "YYY": "YEAR", + "YYYY": "YEAR", + "YR": "YEAR", + "YEARS": "YEAR", + "YRS": "YEAR", + "MM": "MONTH", + "MON": "MONTH", + "MONS": "MONTH", + "MONTHS": "MONTH", + "D": "DAY", + "DD": "DAY", + "DAYS": "DAY", + "DAYOFMONTH": "DAY", + "DAY OF WEEK": "DAYOFWEEK", + "WEEKDAY": "DAYOFWEEK", + "DOW": "DAYOFWEEK", + "DW": "DAYOFWEEK", + "WEEKDAY_ISO": "DAYOFWEEKISO", + "DOW_ISO": "DAYOFWEEKISO", + "DW_ISO": "DAYOFWEEKISO", + "DAYOFWEEK_ISO": "DAYOFWEEKISO", + "DAY OF YEAR": "DAYOFYEAR", + "DOY": "DAYOFYEAR", + "DY": "DAYOFYEAR", + "W": "WEEK", + "WK": "WEEK", + "WEEKOFYEAR": "WEEK", + "WOY": "WEEK", + "WY": "WEEK", + "WEEK_ISO": "WEEKISO", + "WEEKOFYEARISO": "WEEKISO", + "WEEKOFYEAR_ISO": "WEEKISO", + "Q": "QUARTER", + "QTR": "QUARTER", + "QTRS": "QUARTER", + "QUARTERS": "QUARTER", + "H": "HOUR", + "HH": "HOUR", + "HR": "HOUR", + "HOURS": "HOUR", + "HRS": "HOUR", + "M": "MINUTE", + "MI": "MINUTE", + "MIN": "MINUTE", + "MINUTES": "MINUTE", + "MINS": "MINUTE", + "S": "SECOND", + "SEC": "SECOND", + "SECONDS": "SECOND", + "SECS": "SECOND", + "MS": "MILLISECOND", + "MSEC": "MILLISECOND", + "MSECS": "MILLISECOND", + "MSECOND": "MILLISECOND", + "MSECONDS": "MILLISECOND", + "MILLISEC": "MILLISECOND", + "MILLISECS": "MILLISECOND", + "MILLISECON": "MILLISECOND", + "MILLISECONDS": "MILLISECOND", + "US": "MICROSECOND", + "USEC": "MICROSECOND", + "USECS": "MICROSECOND", + "MICROSEC": "MICROSECOND", + "MICROSECS": "MICROSECOND", + "USECOND": "MICROSECOND", + "USECONDS": "MICROSECOND", + "MICROSECONDS": "MICROSECOND", + "NS": "NANOSECOND", + "NSEC": "NANOSECOND", + "NANOSEC": "NANOSECOND", + "NSECOND": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "NANOSECS": "NANOSECOND", 
+ "EPOCH_SECOND": "EPOCH", + "EPOCH_SECONDS": "EPOCH", + "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", + "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", + "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", + "TZH": "TIMEZONE_HOUR", + "TZM": "TIMEZONE_MINUTE", + "DEC": "DECADE", + "DECS": "DECADE", + "DECADES": "DECADE", + "MIL": "MILLENNIUM", + "MILS": "MILLENNIUM", + "MILLENIA": "MILLENNIUM", + "C": "CENTURY", + "CENT": "CENTURY", + "CENTS": "CENTURY", + "CENTURIES": "CENTURY", + } + + # Specifies what types a given type can be coerced into + COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + + # Specifies type inference & validation rules for expressions + EXPRESSION_METADATA = EXPRESSION_METADATA.copy() + + # Determines the supported Dialect instance settings + SUPPORTED_SETTINGS = { + "normalization_strategy", + "version", + } + + @classmethod + def get_or_raise(cls, dialect: DialectType) -> Dialect: + """ + Look up a dialect in the global dialect registry and return it if it exists. + + Args: + dialect: The target dialect. If this is a string, it can be optionally followed by + additional key-value pairs that are separated by commas and are used to specify + dialect settings, such as whether the dialect's identifiers are case-sensitive. + + Example: + >>> dialect = dialect_class = get_or_raise("duckdb") + >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") + + Returns: + The corresponding Dialect instance. + """ + + if not dialect: + return cls() + if isinstance(dialect, _Dialect): + return dialect() + if isinstance(dialect, Dialect): + return dialect + if isinstance(dialect, str): + try: + dialect_name, *kv_strings = dialect.split(",") + kv_pairs = (kv.split("=") for kv in kv_strings) + kwargs = {} + for pair in kv_pairs: + key = pair[0].strip() + value: t.Union[bool | str | None] = None + + if len(pair) == 1: + # Default initialize standalone settings to True + value = True + elif len(pair) == 2: + value = pair[1].strip() + + kwargs[key] = to_bool(value) + + except ValueError: + raise ValueError( + f"Invalid dialect format: '{dialect}'. " + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
+ ) + + result = cls.get(dialect_name.strip()) + if not result: + suggest_closest_match_and_fail( + "dialect", dialect_name, list(DIALECT_MODULE_NAMES) + ) + + assert result is not None + return result(**kwargs) + + raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") + + @classmethod + def format_time( + cls, expression: t.Optional[str | exp.Expression] + ) -> t.Optional[exp.Expression]: + """Converts a time format in this dialect to its equivalent Python `strftime` format.""" + if isinstance(expression, str): + return exp.Literal.string( + # the time formats are quoted + format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) + ) + + if expression and expression.is_string: + return exp.Literal.string( + format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE) + ) + + return expression + + def __init__(self, **kwargs) -> None: + parts = str(kwargs.pop("version", sys.maxsize)).split(".") + parts.extend(["0"] * (3 - len(parts))) + self.version = tuple(int(p) for p in parts[:3]) + + normalization_strategy = kwargs.pop("normalization_strategy", None) + if normalization_strategy is None: + self.normalization_strategy = self.NORMALIZATION_STRATEGY + else: + self.normalization_strategy = NormalizationStrategy( + normalization_strategy.upper() + ) + + self.settings = kwargs + + for unsupported_setting in kwargs.keys() - self.SUPPORTED_SETTINGS: + suggest_closest_match_and_fail( + "setting", unsupported_setting, self.SUPPORTED_SETTINGS + ) + + def __eq__(self, other: t.Any) -> bool: + # Does not currently take dialect state into account + return isinstance(self, other.__class__) + + def __hash__(self) -> int: + # Does not currently take dialect state into account + return hash(type(self)) + + def normalize_identifier(self, expression: E) -> E: + """ + Transforms an identifier in a way that resembles how it'd be resolved by this dialect. + + For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it + lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so + it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, + and so any normalization would be prohibited in order to avoid "breaking" the identifier. + + There are also dialects like Spark, which are case-insensitive even when quotes are + present, and dialects like MySQL, whose resolution rules match those employed by the + underlying operating system, for example they may always be case-sensitive in Linux. + + Finally, the normalization behavior of some engines can even be controlled through flags, + like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. + + SQLGlot aims to understand and handle all of these different behaviors gracefully, so + that it can analyze queries in the optimizer and successfully capture their semantics. 
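+ + A minimal illustration (assuming the Postgres and Snowflake dialect classes + from this package are in scope): + Postgres().normalize_identifier(exp.to_identifier("FoO")).name # "foo" + Snowflake().normalize_identifier(exp.to_identifier("FoO")).name # "FOO"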
+ """ + if ( + isinstance(expression, exp.Identifier) + and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE + and ( + not expression.quoted + or self.normalization_strategy + in ( + NormalizationStrategy.CASE_INSENSITIVE, + NormalizationStrategy.CASE_INSENSITIVE_UPPERCASE, + ) + ) + ): + normalized = ( + expression.this.upper() + if self.normalization_strategy + in ( + NormalizationStrategy.UPPERCASE, + NormalizationStrategy.CASE_INSENSITIVE_UPPERCASE, + ) + else expression.this.lower() + ) + expression.set("this", normalized) + + return expression + + def case_sensitive(self, text: str) -> bool: + """Checks if text contains any case sensitive characters, based on the dialect's rules.""" + if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: + return False + + unsafe = ( + str.islower + if self.normalization_strategy is NormalizationStrategy.UPPERCASE + else str.isupper + ) + return any(unsafe(char) for char in text) + + def can_quote( + self, identifier: exp.Identifier, identify: str | bool = "safe" + ) -> bool: + """Checks if an identifier can be quoted + + Args: + identifier: The identifier to check. + identify: + `True`: Always returns `True` except for certain cases. + `"safe"`: Only returns `True` if the identifier is case-insensitive. + `"unsafe"`: Only returns `True` if the identifier is case-sensitive. + + Returns: + Whether the given text can be identified. + """ + if identifier.quoted: + return True + if not identify: + return False + if isinstance(identifier.parent, exp.Func): + return False + if identify is True: + return True + + is_safe = not self.case_sensitive(identifier.this) and bool( + exp.SAFE_IDENTIFIER_RE.match(identifier.this) + ) + + if identify == "safe": + return is_safe + if identify == "unsafe": + return not is_safe + + raise ValueError(f"Unexpected argument for identify: '{identify}'") + + def quote_identifier(self, expression: E, identify: bool = True) -> E: + """ + Adds quotes to a given expression if it is an identifier. + + Args: + expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. + identify: If set to `False`, the quotes will only be added if the identifier is deemed + "unsafe", with respect to its characters and this dialect's normalization strategy. + """ + if isinstance(expression, exp.Identifier): + expression.set("quoted", self.can_quote(expression, identify or "unsafe")) + return expression + + def to_json_path( + self, path: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if isinstance(path, exp.Literal): + path_text = path.name + if path.is_number: + path_text = f"[{path_text}]" + try: + return parse_json_path(path_text, self) + except ParseError as e: + if self.STRICT_JSON_PATH_SYNTAX and not path_text.lstrip().startswith( + ("lax", "strict") + ): + logger.warning(f"Invalid JSON path syntax. 
{str(e)}") + + return path + + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: + return self.parser(**opts).parse(self.tokenize(sql), sql) + + def parse_into( + self, expression_type: exp.IntoType, sql: str, **opts + ) -> t.List[t.Optional[exp.Expression]]: + return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) + + def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: + return self.generator(**opts).generate(expression, copy=copy) + + def transpile(self, sql: str, **opts) -> t.List[str]: + return [ + self.generate(expression, copy=False, **opts) if expression else "" + for expression in self.parse(sql) + ] + + def tokenize(self, sql: str, **opts) -> t.List[Token]: + return self.tokenizer(**opts).tokenize(sql) + + def tokenizer(self, **opts) -> Tokenizer: + return self.tokenizer_class(**{"dialect": self, **opts}) + + def jsonpath_tokenizer(self, **opts) -> JSONPathTokenizer: + return self.jsonpath_tokenizer_class(**{"dialect": self, **opts}) + + def parser(self, **opts) -> Parser: + return self.parser_class(**{"dialect": self, **opts}) + + def generator(self, **opts) -> Generator: + return self.generator_class(**{"dialect": self, **opts}) + + def generate_values_aliases(self, expression: exp.Values) -> t.List[exp.Identifier]: + return [ + exp.to_identifier(f"_col_{i}") + for i, _ in enumerate(expression.expressions[0].expressions) + ] + + +DialectType = t.Union[str, Dialect, t.Type[Dialect], None] + + +def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: + return lambda self, expression: self.func(name, *flatten(expression.args.values())) + + +@unsupported_args("accuracy") +def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: + return self.func("APPROX_COUNT_DISTINCT", expression.this) + + +def if_sql( + name: str = "IF", false_value: t.Optional[exp.Expression | str] = None +) -> t.Callable[[Generator, exp.If], str]: + def _if_sql(self: Generator, expression: exp.If) -> str: + return self.func( + name, + expression.this, + expression.args.get("true"), + expression.args.get("false") or false_value, + ) + + return _if_sql + + +def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + this = expression.this + if ( + self.JSON_TYPE_REQUIRED_FOR_EXTRACTION + and isinstance(this, exp.Literal) + and this.is_string + ): + this.replace(exp.cast(this, exp.DataType.Type.JSON)) + + return self.binary( + expression, "->" if isinstance(expression, exp.JSONExtract) else "->>" + ) + + +def inline_array_sql(self: Generator, expression: exp.Expression) -> str: + return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" + + +def inline_array_unless_query(self: Generator, expression: exp.Expression) -> str: + elem = seq_get(expression.expressions, 0) + if isinstance(elem, exp.Expression) and elem.find(exp.Query): + return self.func("ARRAY", elem) + return inline_array_sql(self, expression) + + +def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: + return self.like_sql( + exp.Like( + this=exp.Lower(this=expression.this), + expression=exp.Lower(this=expression.expression), + ) + ) + + +def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: + zone = self.sql(expression, "this") + return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" + + +def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: + if expression.args.get("recursive"): + 
self.unsupported("Recursive CTEs are unsupported") + expression.set("recursive", False) + return self.with_sql(expression) + + +def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: + self.unsupported("TABLESAMPLE unsupported") + return self.sql(expression.this) + + +def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: + self.unsupported("PIVOT unsupported") + return "" + + +def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: + return self.cast_sql(expression) + + +def no_comment_column_constraint_sql( + self: Generator, expression: exp.CommentColumnConstraint +) -> str: + self.unsupported("CommentColumnConstraint unsupported") + return "" + + +def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: + self.unsupported("MAP_FROM_ENTRIES unsupported") + return "" + + +def property_sql(self: Generator, expression: exp.Property) -> str: + return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" + + +def strposition_sql( + self: Generator, + expression: exp.StrPosition, + func_name: str = "STRPOS", + supports_position: bool = False, + supports_occurrence: bool = False, + use_ansi_position: bool = True, +) -> str: + string = expression.this + substr = expression.args.get("substr") + position = expression.args.get("position") + occurrence = expression.args.get("occurrence") + zero = exp.Literal.number(0) + one = exp.Literal.number(1) + + if supports_occurrence and occurrence and supports_position and not position: + position = one + + transpile_position = position and not supports_position + if transpile_position: + string = exp.Substring(this=string, start=position) + + if func_name == "POSITION" and use_ansi_position: + func = exp.Anonymous( + this=func_name, expressions=[exp.In(this=substr, field=string)] + ) + else: + args = ( + [substr, string] + if func_name in ("LOCATE", "CHARINDEX") + else [string, substr] + ) + if supports_position: + args.append(position) + if occurrence: + if supports_occurrence: + args.append(occurrence) + else: + self.unsupported( + f"{func_name} does not support the occurrence parameter." + ) + func = exp.Anonymous(this=func_name, expressions=args) + + if transpile_position: + func_with_offset = exp.Sub(this=func + position, expression=one) + func_wrapped = exp.If(this=func.eq(zero), true=zero, false=func_with_offset) + return self.sql(func_wrapped) + + return self.sql(func) + + +def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: + return f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" + + +def var_map_sql( + self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" +) -> str: + keys = expression.args.get("keys") + values = expression.args.get("values") + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map.") + return self.func(map_func_name, keys, values) + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + + return self.func(map_func_name, *args) + + +def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str: + """ + Transpile MONTHS_BETWEEN to dialects that don't have native support. 
+ + Snowflake's MONTHS_BETWEEN returns whole months + fractional part where: + - Fractional part = (DAY(date1) - DAY(date2)) / 31 + - Special case: If both dates are last day of month, fractional part = 0 + + Formula: DATEDIFF('month', date2, date1) + (DAY(date1) - DAY(date2)) / 31.0 + """ + date1 = expression.this + date2 = expression.expression + + # Cast to DATE to ensure consistent behavior + date1_cast = exp.cast(date1, exp.DataType.Type.DATE, copy=False) + date2_cast = exp.cast(date2, exp.DataType.Type.DATE, copy=False) + + # Whole months: DATEDIFF('month', date2, date1) + whole_months = exp.DateDiff( + this=date1_cast, expression=date2_cast, unit=exp.var("month") + ) + + # Day components + day1 = exp.Day(this=date1_cast.copy()) + day2 = exp.Day(this=date2_cast.copy()) + + # Last day of month components + last_day_of_month1 = exp.LastDay(this=date1_cast.copy()) + last_day_of_month2 = exp.LastDay(this=date2_cast.copy()) + + day_of_last_day1 = exp.Day(this=last_day_of_month1) + day_of_last_day2 = exp.Day(this=last_day_of_month2) + + # Check if both are last day of month + last_day1 = exp.EQ(this=day1.copy(), expression=day_of_last_day1) + last_day2 = exp.EQ(this=day2.copy(), expression=day_of_last_day2) + both_last_day = exp.And(this=last_day1, expression=last_day2) + + # Fractional part: (DAY(date1) - DAY(date2)) / 31.0 + fractional = exp.Div( + this=exp.Paren(this=exp.Sub(this=day1.copy(), expression=day2.copy())), + expression=exp.Literal.number("31.0"), + ) + + # If both are last day of month, fractional = 0, else calculate fractional + fractional_with_check = exp.If( + this=both_last_day, true=exp.Literal.number("0"), false=fractional + ) + + # Final result: whole_months + fractional + result = exp.Add(this=whole_months, expression=fractional_with_check) + + return self.sql(result) + + +def build_formatted_time( + exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None +) -> t.Callable[[t.List], E]: + """Helper used for time expressions. + + Args: + exp_class: the expression class to instantiate. + dialect: target sql dialect. + default: the default format, True being time. + + Returns: + A callable that can be used to return the appropriately formatted time expression. + """ + + def _builder(args: t.List): + return exp_class( + this=seq_get(args, 0), + format=Dialect[dialect].format_time( + seq_get(args, 1) + or ( + Dialect[dialect].TIME_FORMAT if default is True else default or None + ) + ), + ) + + return _builder + + +def time_format( + dialect: DialectType = None, +) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: + def _time_format( + self: Generator, expression: exp.UnixToStr | exp.StrToUnix + ) -> t.Optional[str]: + """ + Returns the time format for a given expression, unless it's equivalent + to the default time format of the dialect of interest. 
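+ + For example (illustrative): if the target dialect's TIME_FORMAT is '%Y-%m-%d %H:%M:%S' + and the expression's format renders to that same string, None is returned so the + generator can omit the redundant format argument.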
+ """ + time_format = self.format_time(expression) + return ( + time_format + if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT + else None + ) + + return _time_format + + +def build_date_delta( + exp_class: t.Type[E], + unit_mapping: t.Optional[t.Dict[str, str]] = None, + default_unit: t.Optional[str] = "DAY", + supports_timezone: bool = False, +) -> t.Callable[[t.List], E]: + def _builder(args: t.List) -> E: + unit_based = len(args) >= 3 + has_timezone = len(args) == 4 + this = args[2] if unit_based else seq_get(args, 0) + unit = None + if unit_based or default_unit: + unit = args[0] if unit_based else exp.Literal.string(default_unit) + unit = ( + exp.var(unit_mapping.get(unit.name.lower(), unit.name)) + if unit_mapping + else unit + ) + expression = exp_class(this=this, expression=seq_get(args, 1), unit=unit) + if supports_timezone and has_timezone: + expression.set("zone", args[-1]) + return expression + + return _builder + + +def build_date_delta_with_interval( + expression_class: t.Type[E], +) -> t.Callable[[t.List], t.Optional[E]]: + def _builder(args: t.List) -> t.Optional[E]: + if len(args) < 2: + return None + + interval = args[1] + + if not isinstance(interval, exp.Interval): + raise ParseError(f"INTERVAL expression expected but got '{interval}'") + + return expression_class( + this=args[0], expression=interval.this, unit=unit_to_str(interval) + ) + + return _builder + + +def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: + unit = seq_get(args, 0) + this = seq_get(args, 1) + + if isinstance(this, exp.Cast) and this.is_type("date"): + return exp.DateTrunc(unit=unit, this=this) + return exp.TimestampTrunc(this=this, unit=unit) + + +def date_add_interval_sql( + data_type: str, kind: str +) -> t.Callable[[Generator, exp.Expression], str]: + def func(self: Generator, expression: exp.Expression) -> str: + this = self.sql(expression, "this") + interval = exp.Interval( + this=expression.expression, unit=unit_to_var(expression) + ) + return f"{data_type}_{kind}({this}, {self.sql(interval)})" + + return func + + +def timestamptrunc_sql( + func: str = "DATE_TRUNC", zone: bool = False +) -> t.Callable[[Generator, exp.TimestampTrunc], str]: + def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: + args = [unit_to_str(expression), expression.this] + if zone: + args.append(expression.args.get("zone")) + return self.func(func, *args) + + return _timestamptrunc_sql + + +def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: + zone = expression.args.get("zone") + if not zone: + from sqlglot.optimizer.annotate_types import annotate_types + + target_type = ( + annotate_types(expression, dialect=self.dialect).type + or exp.DataType.Type.TIMESTAMP + ) + return self.sql(exp.cast(expression.this, target_type)) + if zone.name.lower() in TIMEZONES: + return self.sql( + exp.AtTimeZone( + this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), + zone=zone, + ) + ) + return self.func("TIMESTAMP", expression.this, zone) + + +def no_time_sql(self: Generator, expression: exp.Time) -> str: + # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ AT TIME ZONE AS TIME) + this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) + expr = exp.cast( + exp.AtTimeZone(this=this, zone=expression.args.get("zone")), + exp.DataType.Type.TIME, + ) + return self.sql(expr) + + +def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: + this = expression.this + expr = expression.expression + + if expr.name.lower() in 
TIMEZONES: + # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ AT TIME ZONE AS TIMESTAMP) + this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) + this = exp.cast( + exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP + ) + return self.sql(this) + + this = exp.cast(this, exp.DataType.Type.DATE) + expr = exp.cast(expr, exp.DataType.Type.TIME) + + return self.sql( + exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP) + ) + + +def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: + return self.sql( + exp.Substring( + this=expression.this, + start=exp.Literal.number(1), + length=expression.expression, + ) + ) + + +def right_to_substring_sql(self: Generator, expression: exp.Right) -> str: + return self.sql( + exp.Substring( + this=expression.this, + start=exp.Length(this=expression.this) + - exp.paren(expression.expression - 1), + ) + ) + + +def timestrtotime_sql( + self: Generator, + expression: exp.TimeStrToTime, + include_precision: bool = False, +) -> str: + datatype = exp.DataType.build( + exp.DataType.Type.TIMESTAMPTZ + if expression.args.get("zone") + else exp.DataType.Type.TIMESTAMP + ) + + if isinstance(expression.this, exp.Literal) and include_precision: + precision = subsecond_precision(expression.this.name) + if precision > 0: + datatype = exp.DataType.build( + datatype.this, + expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))], + ) + + return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) + + +def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: + return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) + + +# Used for Presto and DuckDB which use functions that don't support charset, and assume utf-8 +def encode_decode_sql( + self: Generator, expression: exp.Expression, name: str, replace: bool = True +) -> str: + charset = expression.args.get("charset") + if charset and charset.name.lower() != "utf-8": + self.unsupported(f"Expected utf-8 character set, got {charset}.") + + return self.func( + name, expression.this, expression.args.get("replace") if replace else None + ) + + +def min_or_least(self: Generator, expression: exp.Min) -> str: + name = "LEAST" if expression.expressions else "MIN" + return rename_func(name)(self, expression) + + +def max_or_greatest(self: Generator, expression: exp.Max) -> str: + name = "GREATEST" if expression.expressions else "MAX" + return rename_func(name)(self, expression) + + +def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: + cond = expression.this + + if isinstance(expression.this, exp.Distinct): + cond = expression.this.expressions[0] + self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") + + return self.func("sum", exp.func("if", cond, 1, 0)) + + +def trim_sql(self: Generator, expression: exp.Trim, default_trim_type: str = "") -> str: + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") or default_trim_type + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific + if not remove_chars: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return 
f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" + + +def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: + return self.func("STRPTIME", expression.this, self.format_time(expression)) + + +def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: + return self.sql( + reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions) + ) + + +def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: + delim, *rest_args = expression.expressions + return self.sql( + reduce( + lambda x, y: exp.DPipe( + this=x, expression=exp.DPipe(this=delim, expression=y) + ), + rest_args, + ) + ) + + +@unsupported_args("position", "occurrence", "parameters") +def regexp_extract_sql( + self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll +) -> str: + group = expression.args.get("group") + + # Do not render group if it's the default value for this dialect + if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): + group = None + + return self.func( + expression.sql_name(), expression.this, expression.expression, group + ) + + +@unsupported_args("position", "occurrence", "modifiers") +def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: + return self.func( + "REGEXP_REPLACE", + expression.this, + expression.expression, + expression.args["replacement"], + ) + + +def pivot_column_names( + aggregations: t.List[exp.Expression], dialect: DialectType +) -> t.List[str]: + names = [] + for agg in aggregations: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
+ """ + agg_all_unquoted = agg.transform( + lambda node: ( + exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + ) + names.append( + agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower") + ) + + return names + + +def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: + return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) + + +# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects +def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: + return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) + + +def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: + return self.func("MAX", expression.this) + + +def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: + a = self.sql(expression.left) + b = self.sql(expression.right) + return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" + + +def is_parse_json(expression: exp.Expression) -> bool: + return isinstance(expression, exp.ParseJSON) or ( + isinstance(expression, exp.Cast) and expression.is_type("json") + ) + + +def isnull_to_is_null(args: t.List) -> exp.Expression: + return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) + + +def generatedasidentitycolumnconstraint_sql( + self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint +) -> str: + start = self.sql(expression, "start") or "1" + increment = self.sql(expression, "increment") or "1" + return f"IDENTITY({start}, {increment})" + + +def arg_max_or_min_no_count( + name: str, +) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: + @unsupported_args("count") + def _arg_max_or_min_sql( + self: Generator, expression: exp.ArgMax | exp.ArgMin + ) -> str: + return self.func(name, expression.this, expression.expression) + + return _arg_max_or_min_sql + + +def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: + this = expression.this.copy() + + return_type = expression.return_type + if return_type.is_type(exp.DataType.Type.DATE): + # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we + # can truncate timestamp strings, because some dialects can't cast them to DATE + this = exp.cast(this, exp.DataType.Type.TIMESTAMP) + + expression.this.replace(exp.cast(this, return_type)) + return expression + + +def date_delta_sql( + name: str, cast: bool = False +) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: + def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: + if cast and isinstance(expression, exp.TsOrDsAdd): + expression = ts_or_ds_add_cast(expression) + + return self.func( + name, + unit_to_var(expression), + expression.expression, + expression.this, + ) + + return _delta_sql + + +def date_delta_to_binary_interval_op( + cast: bool = True, +) -> t.Callable[[Generator, DATETIME_DELTA], str]: + def date_delta_to_binary_interval_op_sql( + self: Generator, expression: DATETIME_DELTA + ) -> str: + this = expression.this + unit = unit_to_var(expression) + op = "+" if isinstance(expression, DATETIME_ADD) else "-" + + to_type: t.Optional[exp.DATA_TYPE] = None + if cast: + if isinstance(expression, exp.TsOrDsAdd): + to_type = expression.return_type + elif this.is_string: + # Cast string literals (i.e function parameters) to the appropriate type for +/- interval to work + to_type = ( + exp.DataType.Type.DATETIME + if isinstance(expression, (exp.DatetimeAdd, exp.DatetimeSub)) + else exp.DataType.Type.DATE + ) + + this = exp.cast(this, to_type) if 
to_type else this + + expr = expression.expression + interval = ( + expr + if isinstance(expr, exp.Interval) + else exp.Interval(this=expr, unit=unit) + ) + + return f"{self.sql(this)} {op} {self.sql(interval)}" + + return date_delta_to_binary_interval_op_sql + + +def unit_to_str( + expression: exp.Expression, default: str = "DAY" +) -> t.Optional[exp.Expression]: + unit = expression.args.get("unit") + if not unit: + return exp.Literal.string(default) if default else None + + if isinstance(unit, exp.Placeholder) or type(unit) not in (exp.Var, exp.Literal): + return unit + + return exp.Literal.string(unit.name) + + +def unit_to_var( + expression: exp.Expression, default: str = "DAY" +) -> t.Optional[exp.Expression]: + unit = expression.args.get("unit") + + if isinstance(unit, (exp.Var, exp.Placeholder, exp.WeekStart, exp.Column)): + return unit + + value = unit.name if unit else default + return exp.Var(this=value) if value else None + + +@t.overload +def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: + pass + + +@t.overload +def map_date_part( + part: t.Optional[exp.Expression], dialect: DialectType = Dialect +) -> t.Optional[exp.Expression]: + pass + + +def map_date_part(part, dialect: DialectType = Dialect): + mapped = ( + Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) + if part and not (isinstance(part, exp.Column) and len(part.parts) != 1) + else None + ) + if mapped: + return exp.Literal.string(mapped) if part.is_string else exp.var(mapped) + + return part + + +def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: + trunc_curr_date = exp.func("date_trunc", "month", expression.this) + plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") + minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") + + return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) + + +def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: + """Remove table refs from columns in when statements.""" + alias = expression.this.args.get("alias") + + def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: + return ( + self.dialect.normalize_identifier(identifier).name if identifier else None + ) + + targets = {normalize(expression.this.this)} + + if alias: + targets.add(normalize(alias.this)) + + for when in expression.args["whens"].expressions: + # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED + # they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part + # (not the column list) of the INSERT + then: exp.Insert | exp.Update | None = when.args.get("then") + if then: + if isinstance(then, exp.Update): + for equals in then.find_all(exp.EQ): + equal_lhs = equals.this + if ( + isinstance(equal_lhs, exp.Column) + and normalize(equal_lhs.args.get("table")) in targets + ): + equal_lhs.replace(exp.column(equal_lhs.this)) + if isinstance(then, exp.Insert): + column_list = then.this + if isinstance(column_list, exp.Tuple): + for column in column_list.expressions: + if normalize(column.args.get("table")) in targets: + column.replace(exp.column(column.this)) + + return self.merge_sql(expression) + + +def build_json_extract_path( + expr_type: t.Type[F], + zero_based_indexing: bool = True, + arrow_req_json_type: bool = False, + json_type: t.Optional[str] = None, +) -> t.Callable[[t.List], F]: + def _builder(args: t.List) -> F: + segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] + for arg in args[1:]: + if not 
isinstance(arg, exp.Literal): + # We use the fallback parser because we can't really transpile non-literals safely + return expr_type.from_arg_list(args) + + text = arg.name + if is_int(text) and (not arrow_req_json_type or not arg.is_string): + index = int(text) + segments.append( + exp.JSONPathSubscript( + this=index if zero_based_indexing else index - 1 + ) + ) + else: + segments.append(exp.JSONPathKey(this=text)) + + # This is done to avoid failing in the expression validator due to the arg count + del args[2:] + kwargs = { + "this": seq_get(args, 0), + "expression": exp.JSONPath(expressions=segments), + } + + is_jsonb = issubclass(expr_type, (exp.JSONBExtract, exp.JSONBExtractScalar)) + if not is_jsonb: + kwargs["only_json_types"] = arrow_req_json_type + + if json_type is not None: + kwargs["json_type"] = json_type + + return expr_type(**kwargs) + + return _builder + + +def json_extract_segments( + name: str, quoted_index: bool = True, op: t.Optional[str] = None +) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: + def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + path = expression.expression + if not isinstance(path, exp.JSONPath): + return rename_func(name)(self, expression) + + escape = path.args.get("escape") + + segments = [] + for segment in path.expressions: + path = self.sql(segment) + if path: + if isinstance(segment, exp.JSONPathPart) and ( + quoted_index or not isinstance(segment, exp.JSONPathSubscript) + ): + if escape: + path = self.escape_str(path) + + path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" + + segments.append(path) + + if op: + return f" {op} ".join([self.sql(expression.this), *segments]) + return self.func(name, expression.this, *segments) + + return _json_extract_segments + + +def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: + if isinstance(expression.this, exp.JSONPathWildcard): + self.unsupported("Unsupported wildcard in JSONPathKey expression") + + return expression.name + + +def filter_array_using_unnest( + self: Generator, expression: exp.ArrayFilter | exp.ArrayRemove +) -> str: + cond = expression.expression + if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: + alias = cond.expressions[0] + cond = cond.this + elif isinstance(cond, exp.Predicate): + alias = "_u" + elif isinstance(expression, exp.ArrayRemove): + alias = "_u" + cond = exp.NEQ(this=alias, expression=expression.expression) + else: + self.unsupported("Unsupported filter condition") + return "" + + unnest = exp.Unnest(expressions=[expression.this]) + filtered = ( + exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) + ) + return self.sql(exp.Array(expressions=[filtered])) + + +def remove_from_array_using_filter(self: Generator, expression: exp.ArrayRemove) -> str: + lambda_id = exp.to_identifier("_u") + cond = exp.NEQ(this=lambda_id, expression=expression.expression) + return self.sql( + exp.ArrayFilter( + this=expression.this, + expression=exp.Lambda(this=cond, expressions=[lambda_id]), + ) + ) + + +def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: + return self.func( + "TO_NUMBER", + expression.this, + expression.args.get("format"), + expression.args.get("nlsparam"), + ) + + +def build_default_decimal_type( + precision: t.Optional[int] = None, scale: t.Optional[int] = None +) -> t.Callable[[exp.DataType], exp.DataType]: + def _builder(dtype: exp.DataType) -> exp.DataType: + if dtype.expressions or precision is None: + return dtype + + 
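# e.g. (illustrative): precision=38, scale=9 builds "DECIMAL(38, 9)" + 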
params = f"{precision}{f', {scale}' if scale is not None else ''}" + return exp.DataType.build(f"DECIMAL({params})") + + return _builder + + +def build_timestamp_from_parts(args: t.List) -> exp.Func: + if len(args) == 2: + # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, + # so we parse this into Anonymous for now instead of introducing complexity + return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) + + return exp.TimestampFromParts.from_arg_list(args) + + +def sha256_sql(self: Generator, expression: exp.SHA2) -> str: + return self.func(f"SHA{expression.text('length') or '256'}", expression.this) + + +def sha2_digest_sql(self: Generator, expression: exp.SHA2Digest) -> str: + return self.func(f"SHA{expression.text('length') or '256'}", expression.this) + + +def sequence_sql( + self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray +) -> str: + start = expression.args.get("start") + end = expression.args.get("end") + step = expression.args.get("step") + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + else: + target_type = None + + if start and end: + if target_type and target_type.is_type("date", "timestamp"): + if isinstance(start, exp.Cast) and target_type is start.to: + end = exp.cast(end, target_type) + else: + start = exp.cast(start, target_type) + + if expression.args.get("is_end_exclusive"): + step_value = step or exp.Literal.number(1) + end = exp.paren(exp.Sub(this=end, expression=step_value), copy=False) + + sequence_call = exp.Anonymous( + this="SEQUENCE", expressions=[e for e in (start, end, step) if e] + ) + zero = exp.Literal.number(0) + should_return_empty = exp.or_( + exp.EQ(this=step_value.copy(), expression=zero.copy()), + exp.and_( + exp.GT(this=step_value.copy(), expression=zero.copy()), + exp.GTE(this=start.copy(), expression=end.copy()), + ), + exp.and_( + exp.LT(this=step_value.copy(), expression=zero.copy()), + exp.LTE(this=start.copy(), expression=end.copy()), + ), + ) + empty_array_or_sequence = exp.If( + this=should_return_empty, + true=exp.Array(expressions=[]), + false=sequence_call, + ) + return self.sql(self._simplify_unless_literal(empty_array_or_sequence)) + + return self.func("SEQUENCE", start, end, step) + + +def build_like( + expr_type: t.Type[E], not_like: bool = False +) -> t.Callable[[t.List], exp.Expression]: + def _builder(args: t.List) -> exp.Expression: + like_expr: exp.Expression = expr_type( + this=seq_get(args, 0), expression=seq_get(args, 1) + ) + + if escape := seq_get(args, 2): + like_expr = exp.Escape(this=like_expr, expression=escape) + + if not_like: + like_expr = exp.Not(this=like_expr) + + return like_expr + + return _builder + + +def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: + def _builder(args: t.List, dialect: Dialect) -> E: + # The "position" argument specifies the index of the string character to start matching from. + # `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string + # length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is + # only needed for exp.RegexpExtract - exp.RegexpExtractAll always returns an empty array if + # position overflows. 
+ return expr_type( + this=seq_get(args, 0), + expression=seq_get(args, 1), + group=seq_get(args, 2) + or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), + parameters=seq_get(args, 3), + **( + { + "null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL + } + if expr_type is exp.RegexpExtract + else {} + ), + ) + + return _builder + + +def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: + if isinstance(expression.this, exp.Explode): + return self.sql( + exp.Join( + this=exp.Unnest( + expressions=[expression.this.this], + alias=expression.args.get("alias"), + offset=isinstance(expression.this, exp.Posexplode), + ), + kind="cross", + ) + ) + return self.lateral_sql(expression) + + +def timestampdiff_sql( + self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff +) -> str: + return self.func( + "TIMESTAMPDIFF", expression.unit, expression.expression, expression.this + ) + + +def no_make_interval_sql( + self: Generator, expression: exp.MakeInterval, sep: str = ", " +) -> str: + args = [] + for unit, value in expression.args.items(): + if isinstance(value, exp.Kwarg): + value = value.expression + + args.append(f"{value} {unit}") + + return f"INTERVAL '{self.format_args(*args, sep=sep)}'" + + +def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str: + length_func = "LENGTH" if expression.args.get("binary") else "CHAR_LENGTH" + return self.func(length_func, expression.this) + + +def groupconcat_sql( + self: Generator, + expression: exp.GroupConcat, + func_name="LISTAGG", + sep: t.Optional[str] = ",", + within_group: bool = True, + on_overflow: bool = False, +) -> str: + this = expression.this + separator = self.sql( + expression.args.get("separator") or (exp.Literal.string(sep) if sep else None) + ) + + on_overflow_sql = self.sql(expression, "on_overflow") + on_overflow_sql = ( + f" ON OVERFLOW {on_overflow_sql}" if (on_overflow and on_overflow_sql) else "" + ) + + if isinstance(this, exp.Limit) and this.this: + limit = this + this = limit.this.pop() + else: + limit = None + + order = this.find(exp.Order) + + if order and order.this: + this = order.this.pop() + + args = self.format_args( + this, f"{separator}{on_overflow_sql}" if separator or on_overflow_sql else None + ) + + listagg: exp.Expression = exp.Anonymous(this=func_name, expressions=[args]) + + modifiers = self.sql(limit) + + if order: + if within_group: + listagg = exp.WithinGroup(this=listagg, expression=order) + else: + modifiers = f"{self.sql(order)}{modifiers}" + + if modifiers: + listagg.set("expressions", [f"{args}{modifiers}"]) + + return self.sql(listagg) + + +def build_timetostr_or_tochar( + args: t.List, dialect: DialectType +) -> exp.TimeToStr | exp.ToChar: + if len(args) == 2: + this = args[0] + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(this, dialect=dialect) + + if this.is_type(*exp.DataType.TEMPORAL_TYPES): + dialect_name = dialect.__class__.__name__.lower() + return build_formatted_time(exp.TimeToStr, dialect_name, default=True)(args) + + return exp.ToChar.from_arg_list(args) + + +def build_replace_with_optional_replacement(args: t.List) -> exp.Replace: + return exp.Replace( + this=seq_get(args, 0), + expression=seq_get(args, 1), + replacement=seq_get(args, 2) or exp.Literal.string(""), + ) + + +def regexp_replace_global_modifier( + expression: exp.RegexpReplace, +) -> exp.Expression | None: + modifiers = expression.args.get("modifiers") + single_replace = 
expression.args.get("single_replace") + occurrence = expression.args.get("occurrence") + + if not single_replace and ( + not occurrence or (occurrence.is_int and occurrence.to_py() == 0) + ): + if not modifiers or modifiers.is_string: + # Append 'g' to the modifiers if they are not provided since + # the semantics of REGEXP_REPLACE from the input dialect + # is to replace all occurrences of the pattern. + value = "" if not modifiers else modifiers.name + modifiers = exp.Literal.string(value + "g") + + return modifiers diff --git a/third_party/bigframes_vendored/sqlglot/diff.py b/third_party/bigframes_vendored/sqlglot/diff.py new file mode 100644 index 0000000000..0e15f7c1ad --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/diff.py @@ -0,0 +1,511 @@ +""" +.. include:: ../posts/sql_diff.md + +---- +""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from heapq import heappop, heappush +from itertools import chain +import typing as t + +from bigframes_vendored.sqlglot import Dialect +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.helper import seq_get + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + +@dataclass(frozen=True) +class Insert: + """Indicates that a new node has been inserted""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Remove: + """Indicates that an existing node has been removed""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Move: + """Indicates that an existing node's position within the tree has changed""" + + source: exp.Expression + target: exp.Expression + + +@dataclass(frozen=True) +class Update: + """Indicates that an existing node has been updated""" + + source: exp.Expression + target: exp.Expression + + +@dataclass(frozen=True) +class Keep: + """Indicates that an existing node hasn't been changed""" + + source: exp.Expression + target: exp.Expression + + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import T + + Edit = t.Union[Insert, Remove, Move, Update, Keep] + + +def diff( + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + delta_only: bool = False, + **kwargs: t.Any, +) -> t.List[Edit]: + """ + Returns the list of changes between the source and the target expressions. + + Examples: + >>> diff(parse_one("a + b"), parse_one("a + c")) + [ + Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))), + Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))), + Keep( + source=(ADD this: ...), + target=(ADD this: ...) + ), + Keep( + source=(COLUMN this: (IDENTIFIER this: a, quoted: False)), + target=(COLUMN this: (IDENTIFIER this: a, quoted: False)) + ), + ] + + Args: + source: the source expression. + target: the target expression against which the diff should be calculated. + matchings: the list of pre-matched node pairs which is used to help the algorithm's + heuristics produce better results for subtrees that are known by a caller to be matching. + Note: expression references in this list must refer to the same node objects that are + referenced in the source / target trees. + delta_only: excludes all `Keep` nodes from the diff. + kwargs: additional arguments to pass to the ChangeDistiller instance. + + Returns: + the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the + target expression trees. 
This list represents a sequence of steps needed to transform the source + expression tree into the target one. + """ + matchings = matchings or [] + + def compute_node_mappings( + old_nodes: tuple[exp.Expression, ...], new_nodes: tuple[exp.Expression, ...] + ) -> t.Dict[int, exp.Expression]: + node_mapping = {} + for old_node, new_node in zip(reversed(old_nodes), reversed(new_nodes)): + new_node._hash = hash(new_node) + node_mapping[id(old_node)] = new_node + + return node_mapping + + # if the source and target have any shared objects, that means there's an issue with the ast + # the algorithm won't work because the parent / hierarchies will be inaccurate + source_nodes = tuple(source.walk()) + target_nodes = tuple(target.walk()) + source_ids = {id(n) for n in source_nodes} + target_ids = {id(n) for n in target_nodes} + + copy = ( + len(source_nodes) != len(source_ids) + or len(target_nodes) != len(target_ids) + or source_ids & target_ids + ) + + source_copy = source.copy() if copy else source + target_copy = target.copy() if copy else target + + try: + # We cache the hash of each new node here to speed up equality comparisons. If the input + # trees aren't copied, these hashes will be evicted before returning the edit script. + if copy and matchings: + source_mapping = compute_node_mappings( + source_nodes, tuple(source_copy.walk()) + ) + target_mapping = compute_node_mappings( + target_nodes, tuple(target_copy.walk()) + ) + matchings = [ + (source_mapping[id(s)], target_mapping[id(t)]) for s, t in matchings + ] + else: + for node in chain(reversed(source_nodes), reversed(target_nodes)): + node._hash = hash(node) + + edit_script = ChangeDistiller(**kwargs).diff( + source_copy, + target_copy, + matchings=matchings, + delta_only=delta_only, + ) + finally: + if not copy: + for node in chain(source_nodes, target_nodes): + node._hash = None + + return edit_script + + +# The expression types for which Update edits are allowed. +UPDATABLE_EXPRESSION_TYPES = ( + exp.Alias, + exp.Boolean, + exp.Column, + exp.DataType, + exp.Lambda, + exp.Literal, + exp.Table, + exp.Window, +) + +IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,) + + +class ChangeDistiller: + """ + The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in + their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by + Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. 
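+
+ Roughly, `f` below is the bigram (dice-coefficient) similarity threshold used when
+ comparing two nodes, and `t` is the minimum share of matched leaves required to pair
+ two inner nodes; both default to 0.6.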
+ """ + + def __init__( + self, f: float = 0.6, t: float = 0.6, dialect: DialectType = None + ) -> None: + self.f = f + self.t = t + self._sql_generator = Dialect.get_or_raise(dialect).generator() + + def diff( + self, + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + delta_only: bool = False, + ) -> t.List[Edit]: + matchings = matchings or [] + pre_matched_nodes = {id(s): id(t) for s, t in matchings} + + self._source = source + self._target = target + self._source_index = { + id(n): n + for n in self._source.bfs() + if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + } + self._target_index = { + id(n): n + for n in self._target.bfs() + if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + } + self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) + self._unmatched_target_nodes = set(self._target_index) - set( + pre_matched_nodes.values() + ) + self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} + + matching_set = self._compute_matching_set() | set(pre_matched_nodes.items()) + return self._generate_edit_script(dict(matching_set), delta_only) + + def _generate_edit_script( + self, matchings: t.Dict[int, int], delta_only: bool + ) -> t.List[Edit]: + edit_script: t.List[Edit] = [] + for removed_node_id in self._unmatched_source_nodes: + edit_script.append(Remove(self._source_index[removed_node_id])) + for inserted_node_id in self._unmatched_target_nodes: + edit_script.append(Insert(self._target_index[inserted_node_id])) + for kept_source_node_id, kept_target_node_id in matchings.items(): + source_node = self._source_index[kept_source_node_id] + target_node = self._target_index[kept_target_node_id] + + identical_nodes = source_node == target_node + + if ( + not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) + or identical_nodes + ): + if identical_nodes: + source_parent = source_node.parent + target_parent = target_node.parent + + if ( + (source_parent and not target_parent) + or (not source_parent and target_parent) + or ( + source_parent + and target_parent + and matchings.get(id(source_parent)) != id(target_parent) + ) + ): + edit_script.append(Move(source=source_node, target=target_node)) + else: + edit_script.extend( + self._generate_move_edits(source_node, target_node, matchings) + ) + + source_non_expression_leaves = dict( + _get_non_expression_leaves(source_node) + ) + target_non_expression_leaves = dict( + _get_non_expression_leaves(target_node) + ) + + if source_non_expression_leaves != target_non_expression_leaves: + edit_script.append(Update(source_node, target_node)) + elif not delta_only: + edit_script.append(Keep(source_node, target_node)) + else: + edit_script.append(Update(source_node, target_node)) + + return edit_script + + def _generate_move_edits( + self, + source: exp.Expression, + target: exp.Expression, + matchings: t.Dict[int, int], + ) -> t.List[Move]: + source_args = [id(e) for e in _expression_only_args(source)] + target_args = [id(e) for e in _expression_only_args(target)] + + args_lcs = set( + _lcs( + source_args, + target_args, + lambda ll, r: matchings.get(t.cast(int, ll)) == r, + ) + ) + + move_edits = [] + for a in source_args: + if a not in args_lcs and a not in self._unmatched_source_nodes: + move_edits.append( + Move( + source=self._source_index[a], + target=self._target_index[matchings[a]], + ) + ) + + return move_edits + + def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]: + leaves_matching_set = 
self._compute_leaf_matching_set() + matching_set = leaves_matching_set.copy() + + ordered_unmatched_source_nodes = { + id(n): None + for n in self._source.bfs() + if id(n) in self._unmatched_source_nodes + } + ordered_unmatched_target_nodes = { + id(n): None + for n in self._target.bfs() + if id(n) in self._unmatched_target_nodes + } + + for source_node_id in ordered_unmatched_source_nodes: + for target_node_id in ordered_unmatched_target_nodes: + source_node = self._source_index[source_node_id] + target_node = self._target_index[target_node_id] + if _is_same_type(source_node, target_node): + source_leaf_ids = { + id(ll) for ll in _get_expression_leaves(source_node) + } + target_leaf_ids = { + id(ll) for ll in _get_expression_leaves(target_node) + } + + max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) + if max_leaves_num: + common_leaves_num = sum( + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 + for s, t in leaves_matching_set + ) + leaf_similarity_score = common_leaves_num / max_leaves_num + else: + leaf_similarity_score = 0.0 + + adjusted_t = ( + self.t + if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 + else 0.4 + ) + + if leaf_similarity_score >= 0.8 or ( + leaf_similarity_score >= adjusted_t + and self._dice_coefficient(source_node, target_node) >= self.f + ): + matching_set.add((source_node_id, target_node_id)) + self._unmatched_source_nodes.remove(source_node_id) + self._unmatched_target_nodes.remove(target_node_id) + ordered_unmatched_target_nodes.pop(target_node_id, None) + break + + return matching_set + + def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]: + candidate_matchings: t.List[ + t.Tuple[float, int, int, exp.Expression, exp.Expression] + ] = [] + source_expression_leaves = list(_get_expression_leaves(self._source)) + target_expression_leaves = list(_get_expression_leaves(self._target)) + for source_leaf in source_expression_leaves: + for target_leaf in target_expression_leaves: + if _is_same_type(source_leaf, target_leaf): + similarity_score = self._dice_coefficient(source_leaf, target_leaf) + if similarity_score >= self.f: + heappush( + candidate_matchings, + ( + -similarity_score, + -_parent_similarity_score(source_leaf, target_leaf), + len(candidate_matchings), + source_leaf, + target_leaf, + ), + ) + + # Pick best matchings based on the highest score + matching_set = set() + while candidate_matchings: + _, _, _, source_leaf, target_leaf = heappop(candidate_matchings) + if ( + id(source_leaf) in self._unmatched_source_nodes + and id(target_leaf) in self._unmatched_target_nodes + ): + matching_set.add((id(source_leaf), id(target_leaf))) + self._unmatched_source_nodes.remove(id(source_leaf)) + self._unmatched_target_nodes.remove(id(target_leaf)) + + return matching_set + + def _dice_coefficient( + self, source: exp.Expression, target: exp.Expression + ) -> float: + source_histo = self._bigram_histo(source) + target_histo = self._bigram_histo(target) + + total_grams = sum(source_histo.values()) + sum(target_histo.values()) + if not total_grams: + return 1.0 if source == target else 0.0 + + overlap_len = 0 + overlapping_grams = set(source_histo) & set(target_histo) + for g in overlapping_grams: + overlap_len += min(source_histo[g], target_histo[g]) + + return 2 * overlap_len / total_grams + + def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]: + if id(expression) in self._bigram_histo_cache: + return self._bigram_histo_cache[id(expression)] + + expression_str = 
self._sql_generator.generate(expression) + count = max(0, len(expression_str) - 1) + bigram_histo: t.DefaultDict[str, int] = defaultdict(int) + for i in range(count): + bigram_histo[expression_str[i : i + 2]] += 1 + + self._bigram_histo_cache[id(expression)] = bigram_histo + return bigram_histo + + +def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: + has_child_exprs = False + + for node in expression.iter_expressions(): + if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES): + has_child_exprs = True + yield from _get_expression_leaves(node) + + if not has_child_exprs: + yield expression + + +def _get_non_expression_leaves( + expression: exp.Expression, +) -> t.Iterator[t.Tuple[str, t.Any]]: + for arg, value in expression.args.items(): + if ( + value is None + or isinstance(value, exp.Expression) + or ( + isinstance(value, list) + and isinstance(seq_get(value, 0), exp.Expression) + ) + ): + continue + + yield (arg, value) + + +def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: + if type(source) is type(target): + if isinstance(source, exp.Join): + return source.args.get("side") == target.args.get("side") + + if isinstance(source, exp.Anonymous): + return source.this == target.this + + return True + + return False + + +def _parent_similarity_score( + source: t.Optional[exp.Expression], target: t.Optional[exp.Expression] +) -> int: + if source is None or target is None or type(source) is not type(target): + return 0 + + return 1 + _parent_similarity_score(source.parent, target.parent) + + +def _expression_only_args(expression: exp.Expression) -> t.Iterator[exp.Expression]: + yield from ( + arg + for arg in expression.iter_expressions() + if not isinstance(arg, IGNORED_LEAF_EXPRESSION_TYPES) + ) + + +def _lcs( + seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool] +) -> t.Sequence[t.Optional[T]]: + """Calculates the longest common subsequence""" + + len_a = len(seq_a) + len_b = len(seq_b) + lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)] + + for i in range(len_a + 1): + for j in range(len_b + 1): + if i == 0 or j == 0: + lcs_result[i][j] = [] # type: ignore + elif equal(seq_a[i - 1], seq_b[j - 1]): + lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore + else: + lcs_result[i][j] = ( + lcs_result[i - 1][j] + if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore + else lcs_result[i][j - 1] + ) + + return lcs_result[len_a][len_b] # type: ignore diff --git a/third_party/bigframes_vendored/sqlglot/errors.py b/third_party/bigframes_vendored/sqlglot/errors.py new file mode 100644 index 0000000000..10e59362b1 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/errors.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from enum import auto +import typing as t + +from bigframes_vendored.sqlglot.helper import AutoName + +# ANSI escape codes for error formatting +ANSI_UNDERLINE = "\033[4m" +ANSI_RESET = "\033[0m" +ERROR_MESSAGE_CONTEXT_DEFAULT = 100 + + +class ErrorLevel(AutoName): + IGNORE = auto() + """Ignore all errors.""" + + WARN = auto() + """Log all errors.""" + + RAISE = auto() + """Collect all errors and raise a single exception.""" + + IMMEDIATE = auto() + """Immediately raise an exception on the first error found.""" + + +class SqlglotError(Exception): + pass + + +class UnsupportedError(SqlglotError): + pass + + +class ParseError(SqlglotError): + def __init__( + self, + message: str, + errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None, + 
): + super().__init__(message) + self.errors = errors or [] + + @classmethod + def new( + cls, + message: str, + description: t.Optional[str] = None, + line: t.Optional[int] = None, + col: t.Optional[int] = None, + start_context: t.Optional[str] = None, + highlight: t.Optional[str] = None, + end_context: t.Optional[str] = None, + into_expression: t.Optional[str] = None, + ) -> ParseError: + return cls( + message, + [ + { + "description": description, + "line": line, + "col": col, + "start_context": start_context, + "highlight": highlight, + "end_context": end_context, + "into_expression": into_expression, + } + ], + ) + + +class TokenError(SqlglotError): + pass + + +class OptimizeError(SqlglotError): + pass + + +class SchemaError(SqlglotError): + pass + + +class ExecuteError(SqlglotError): + pass + + +def highlight_sql( + sql: str, + positions: t.List[t.Tuple[int, int]], + context_length: int = ERROR_MESSAGE_CONTEXT_DEFAULT, +) -> t.Tuple[str, str, str, str]: + """ + Highlight a SQL string using ANSI codes at the given positions. + + Args: + sql: The complete SQL string. + positions: List of (start, end) tuples where both start and end are inclusive 0-based + indexes. For example, to highlight "foo" in "SELECT foo", use (7, 9). + The positions will be sorted and de-duplicated if they overlap. + context_length: Number of characters to show before the first highlight and after + the last highlight. + + Returns: + A tuple of (formatted_sql, start_context, highlight, end_context) where: + - formatted_sql: The SQL with ANSI underline codes applied to highlighted sections + - start_context: Plain text before the first highlight + - highlight: Plain text from the first highlight start to the last highlight end, + including any non-highlighted text in between (no ANSI) + - end_context: Plain text after the last highlight + + Note: + If positions is empty, raises a ValueError. + """ + if not positions: + raise ValueError("positions must contain at least one (start, end) tuple") + + start_context = "" + end_context = "" + first_highlight_start = 0 + formatted_parts = [] + previous_part_end = 0 + sorted_positions = sorted(positions, key=lambda pos: pos[0]) + + if sorted_positions[0][0] > 0: + first_highlight_start = sorted_positions[0][0] + start_context = sql[ + max(0, first_highlight_start - context_length) : first_highlight_start + ] + formatted_parts.append(start_context) + previous_part_end = first_highlight_start + + for start, end in sorted_positions: + highlight_start = max(start, previous_part_end) + highlight_end = end + 1 + if highlight_start >= highlight_end: + continue # Skip invalid or overlapping highlights + if highlight_start > previous_part_end: + formatted_parts.append(sql[previous_part_end:highlight_start]) + formatted_parts.append( + f"{ANSI_UNDERLINE}{sql[highlight_start:highlight_end]}{ANSI_RESET}" + ) + previous_part_end = highlight_end + + if previous_part_end < len(sql): + end_context = sql[previous_part_end : previous_part_end + context_length] + formatted_parts.append(end_context) + + formatted_sql = "".join(formatted_parts) + highlight = sql[first_highlight_start:previous_part_end] + + return formatted_sql, start_context, highlight, end_context + + +def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str: + msg = [str(e) for e in errors[:maximum]] + remaining = len(errors) - maximum + if remaining > 0: + msg.append(f"... 
and {remaining} more") + return "\n\n".join(msg) + + +def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]: + return [e_dict for error in errors for e_dict in error.errors] diff --git a/third_party/bigframes_vendored/sqlglot/executor/__init__.py b/third_party/bigframes_vendored/sqlglot/executor/__init__.py new file mode 100644 index 0000000000..3bc3f8175a --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/executor/__init__.py @@ -0,0 +1,104 @@ +""" +.. include:: ../../posts/python_sql_engine.md + +---- +""" + +from __future__ import annotations + +import logging +import time +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.errors import ExecuteError +from bigframes_vendored.sqlglot.executor.python import PythonExecutor +from bigframes_vendored.sqlglot.optimizer import optimize +from bigframes_vendored.sqlglot.optimizer.annotate_types import annotate_types +from bigframes_vendored.sqlglot.planner import Plan +from bigframes_vendored.sqlglot.schema import ( + ensure_schema, + flatten_schema, + nested_get, + nested_set, +) +from bigframes_vendored.sqlglotes_vendored.sqlglot.executor.table import ( + ensure_tables, + Table, +) +from bigframes_vendored.sqlglotes_vendored.sqlglot.helper import dict_depth + +logger = logging.getLogger("sqlglot") + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from bigframes_vendored.sqlglot.expressions import Expression + from bigframes_vendored.sqlglot.schema import Schema + + +def execute( + sql: str | Expression, + schema: t.Optional[t.Dict | Schema] = None, + dialect: DialectType = None, + tables: t.Optional[t.Dict] = None, +) -> Table: + """ + Run a sql query against data. + + Args: + sql: a sql statement. + schema: database schema. + This can either be an instance of `Schema` or a mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + dialect: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + tables: additional tables to register. + + Returns: + Simple columnar data structure. 
+ """ + tables_ = ensure_tables(tables, dialect=dialect) + + if not schema: + schema = {} + flattened_tables = flatten_schema( + tables_.mapping, depth=dict_depth(tables_.mapping) + ) + + for keys in flattened_tables: + table = nested_get(tables_.mapping, *zip(keys, keys)) + assert table is not None + + for column in table.columns: + value = table[0][column] + column_type = ( + annotate_types(exp.convert(value), dialect=dialect).type + or type(value).__name__ + ) + nested_set(schema, [*keys, column], column_type) + + schema = ensure_schema(schema, dialect=dialect) + + if ( + tables_.supported_table_args + and tables_.supported_table_args != schema.supported_table_args + ): + raise ExecuteError("Tables must support the same table args as schema") + + now = time.time() + expression = optimize(sql, schema, leave_tables_isolated=True, dialect=dialect) + + logger.debug("Optimization finished: %f", time.time() - now) + logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) + + plan = Plan(expression) + + logger.debug("Logical Plan: %s", plan) + + now = time.time() + result = PythonExecutor(tables=tables_).execute(plan) + + logger.debug("Query finished: %f", time.time() - now) + + return result diff --git a/third_party/bigframes_vendored/sqlglot/executor/context.py b/third_party/bigframes_vendored/sqlglot/executor/context.py new file mode 100644 index 0000000000..630fc2d357 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/executor/context.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot.executor.env import ENV + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.executor.table import Table, TableIter + + +class Context: + """ + Execution context for sql expressions. + + Context is used to hold relevant data tables which can then be queried on with eval. + + References to columns can either be scalar or vectors. When set_row is used, column references + evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient + evaluation of aggregation functions. + """ + + def __init__( + self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None + ) -> None: + """ + Args + tables: representing the scope of the current execution context. + env: dictionary of functions within the execution context. 
+ """ + self.tables = tables + self._table: t.Optional[Table] = None + self.range_readers = { + name: table.range_reader for name, table in self.tables.items() + } + self.row_readers = {name: table.reader for name, table in tables.items()} + self.env = {**ENV, **(env or {}), "scope": self.row_readers} + + def eval(self, code): + return eval(code, self.env) + + def eval_tuple(self, codes): + return tuple(self.eval(code) for code in codes) + + @property + def table(self) -> Table: + if self._table is None: + self._table = list(self.tables.values())[0] + + for other in self.tables.values(): + if self._table.columns != other.columns: + raise Exception("Columns are different.") + if len(self._table.rows) != len(other.rows): + raise Exception("Rows are different.") + + return self._table + + def add_columns(self, *columns: str) -> None: + for table in self.tables.values(): + table.add_columns(*columns) + + @property + def columns(self) -> t.Tuple: + return self.table.columns + + def __iter__(self): + self.env["scope"] = self.row_readers + for i in range(len(self.table.rows)): + for table in self.tables.values(): + reader = table[i] + yield reader, self + + def table_iter(self, table: str) -> TableIter: + self.env["scope"] = self.row_readers + return iter(self.tables[table]) + + def filter(self, condition) -> None: + rows = [reader.row for reader, _ in self if self.eval(condition)] + + for table in self.tables.values(): + table.rows = rows + + def sort(self, key) -> None: + def sort_key(row: t.Tuple) -> t.Tuple: + self.set_row(row) + return tuple((t is None, t) for t in self.eval_tuple(key)) + + self.table.rows.sort(key=sort_key) + + def set_row(self, row: t.Tuple) -> None: + for table in self.tables.values(): + table.reader.row = row + self.env["scope"] = self.row_readers + + def set_index(self, index: int) -> None: + for table in self.tables.values(): + table[index] + self.env["scope"] = self.row_readers + + def set_range(self, start: int, end: int) -> None: + for name in self.tables: + self.range_readers[name].range = range(start, end) + self.env["scope"] = self.range_readers + + def __contains__(self, table: str) -> bool: + return table in self.tables diff --git a/third_party/bigframes_vendored/sqlglot/executor/env.py b/third_party/bigframes_vendored/sqlglot/executor/env.py new file mode 100644 index 0000000000..c3688c09a6 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/executor/env.py @@ -0,0 +1,258 @@ +import datetime +from functools import wraps +import inspect +import re +import statistics + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.generator import Generator +from bigframes_vendored.sqlglot.helper import is_int, PYTHON_VERSION, seq_get + + +class reverse_key: + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + return other.obj == self.obj + + def __lt__(self, other): + return other.obj < self.obj + + +def filter_nulls(func, empty_null=True): + @wraps(func) + def _func(values): + filtered = tuple(v for v in values if v is not None) + if not filtered and empty_null: + return None + return func(filtered) + + return _func + + +def null_if_any(*required): + """ + Decorator that makes a function return `None` if any of the `required` arguments are `None`. + + This also supports decoration with no arguments, e.g.: + + @null_if_any + def foo(a, b): ... + + In which case all arguments are required. 
+ """ + f = None + if len(required) == 1 and callable(required[0]): + f = required[0] + required = () + + def decorator(func): + if required: + required_indices = [ + i + for i, param in enumerate(inspect.signature(func).parameters) + if param in required + ] + + def predicate(*args): + return any(args[i] is None for i in required_indices) + + else: + + def predicate(*args): + return any(a is None for a in args) + + @wraps(func) + def _func(*args): + if predicate(*args): + return None + return func(*args) + + return _func + + if f: + return decorator(f) + + return decorator + + +@null_if_any("this", "substr") +def str_position(this, substr, position=None): + position = position - 1 if position is not None else position + return this.find(substr, position) + 1 + + +@null_if_any("this") +def substring(this, start=None, length=None): + if start is None: + return this + elif start == 0: + return "" + elif start < 0: + start = len(this) + start + else: + start -= 1 + + end = None if length is None else start + length + + return this[start:end] + + +@null_if_any +def cast(this, to): + if to == exp.DataType.Type.DATE: + if isinstance(this, datetime.datetime): + return this.date() + if isinstance(this, datetime.date): + return this + if isinstance(this, str): + return datetime.date.fromisoformat(this) + if to == exp.DataType.Type.TIME: + if isinstance(this, datetime.datetime): + return this.time() + if isinstance(this, datetime.time): + return this + if isinstance(this, str): + return datetime.time.fromisoformat(this) + if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP): + if isinstance(this, datetime.datetime): + return this + if isinstance(this, datetime.date): + return datetime.datetime(this.year, this.month, this.day) + if isinstance(this, str): + return datetime.datetime.fromisoformat(this) + if to == exp.DataType.Type.BOOLEAN: + return bool(this) + if to in exp.DataType.TEXT_TYPES: + return str(this) + if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: + return float(this) + if to in exp.DataType.NUMERIC_TYPES: + return int(this) + raise NotImplementedError(f"Casting {this} to '{to}' not implemented.") + + +def ordered(this, desc, nulls_first): + if desc: + return reverse_key(this) + return this + + +@null_if_any +def interval(this, unit): + plural = unit + "S" + if plural in Generator.TIME_PART_SINGULARS: + unit = plural + return datetime.timedelta(**{unit.lower(): float(this)}) + + +@null_if_any("this", "expression") +def arraytostring(this, expression, null=None): + return expression.join( + x for x in (x if x is not None else null for x in this) if x is not None + ) + + +@null_if_any("this", "expression") +def jsonextract(this, expression): + for path_segment in expression: + if isinstance(this, dict): + this = this.get(path_segment) + elif isinstance(this, list) and is_int(path_segment): + this = seq_get(this, int(path_segment)) + else: + raise NotImplementedError( + f"Unable to extract value for {this} at {path_segment}." 
+ ) + + if this is None: + break + + return this + + +ENV = { + "exp": exp, + # aggs + "ARRAYAGG": list, + "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))), + "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), + "MAX": filter_nulls(max), + "MIN": filter_nulls(min), + "SUM": filter_nulls(sum), + # scalar functions + "ABS": null_if_any(lambda this: abs(this)), + "ADD": null_if_any(lambda e, this: e + this), + "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)), + "ARRAYTOSTRING": arraytostring, + "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), + "BITWISEAND": null_if_any(lambda this, e: this & e), + "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), + "BITWISEOR": null_if_any(lambda this, e: this | e), + "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e), + "BITWISEXOR": null_if_any(lambda this, e: this ^ e), + "CAST": cast, + "COALESCE": lambda *args: next((a for a in args if a is not None), None), + "CONCAT": null_if_any(lambda *args: "".join(args)), + "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)), + "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days), + "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), + "DIV": null_if_any(lambda e, this: e / this), + "DOT": null_if_any(lambda e, this: e[this]), + "EQ": null_if_any(lambda this, e: this == e), + "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), + "GT": null_if_any(lambda this, e: this > e), + "GTE": null_if_any(lambda this, e: this >= e), + "IF": lambda predicate, true, false: true if predicate else false, + "INTDIV": null_if_any(lambda e, this: e // this), + "INTERVAL": interval, + "JSONEXTRACT": jsonextract, + "LEFT": null_if_any(lambda this, e: this[:e]), + "LIKE": null_if_any( + lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) + ), + "LOWER": null_if_any(lambda arg: arg.lower()), + "LT": null_if_any(lambda this, e: this < e), + "LTE": null_if_any(lambda this, e: this <= e), + "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore + "MOD": null_if_any(lambda e, this: e % this), + "MUL": null_if_any(lambda e, this: e * this), + "NEQ": null_if_any(lambda this, e: this != e), + "ORD": null_if_any(ord), + "ORDERED": ordered, + "POW": pow, + "RIGHT": null_if_any(lambda this, e: this[-e:]), + "ROUND": null_if_any( + lambda this, decimals=None, truncate=None: round(this, ndigits=decimals) + ), + "STRPOSITION": str_position, + "SUB": null_if_any(lambda e, this: e - this), + "SUBSTRING": substring, + "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)), + "UPPER": null_if_any(lambda arg: arg.upper()), + "YEAR": null_if_any(lambda arg: arg.year), + "MONTH": null_if_any(lambda arg: arg.month), + "DAY": null_if_any(lambda arg: arg.day), + "CURRENTDATETIME": datetime.datetime.now, + "CURRENTTIMESTAMP": datetime.datetime.now, + "CURRENTTIME": datetime.datetime.now, + "CURRENTDATE": datetime.date.today, + "STRFTIME": null_if_any( + lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt) + ), + "STRTOTIME": null_if_any( + lambda arg, format: datetime.datetime.strptime(arg, format) + ), + "TRIM": null_if_any(lambda this, e=None: this.strip(e)), + "STRUCT": lambda *args: { + args[x]: args[x + 1] + for x in range(0, len(args), 2) + if 
(args[x + 1] is not None and args[x] is not None) + }, + "UNIXTOTIME": null_if_any( + lambda arg: datetime.datetime.fromtimestamp(arg, datetime.timezone.utc) + ), +} diff --git a/third_party/bigframes_vendored/sqlglot/executor/python.py b/third_party/bigframes_vendored/sqlglot/executor/python.py new file mode 100644 index 0000000000..cea7f0fc56 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/executor/python.py @@ -0,0 +1,447 @@ +import collections +import itertools +import math + +from bigframes_vendored.sqlglot import exp, generator, planner, tokens +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, inline_array_sql +from bigframes_vendored.sqlglot.errors import ExecuteError +from bigframes_vendored.sqlglot.executor.context import Context +from bigframes_vendored.sqlglot.executor.env import ENV +from bigframes_vendored.sqlglot.executor.table import RowReader, Table +from bigframes_vendored.sqlglot.helper import subclasses + + +class PythonExecutor: + def __init__(self, env=None, tables=None): + self.generator = Python().generator(identify=True, comments=False) + self.env = {**ENV, **(env or {})} + self.tables = tables or {} + + def execute(self, plan): + finished = set() + queue = set(plan.leaves) + contexts = {} + + while queue: + node = queue.pop() + try: + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + elif isinstance(node, planner.SetOperation): + contexts[node] = self.set_operation(node, context) + else: + raise NotImplementedError + + finished.add(node) + + for dep in node.dependents: + if all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + except Exception as e: + raise ExecuteError(f"Step '{node.id}' failed: {e}") from e + + root = plan.root + return contexts[root].tables[root.name] + + def generate(self, expression): + """Convert a SQL expression into literal Python code and compile it into bytecode.""" + if not expression: + return None + + sql = self.generator.generate(expression) + return compile(sql, sql, "eval", optimize=2) + + def generate_tuple(self, expressions): + """Convert an array of SQL expressions into tuple of Python byte code.""" + if not expressions: + return tuple() + return tuple(self.generate(expression) for expression in expressions) + + def context(self, tables): + return Context(tables, env=self.env) + + def table(self, expressions): + return Table( + expression.alias_or_name + if isinstance(expression, exp.Expression) + else expression + for expression in expressions + ) + + def scan(self, step, context): + source = step.source + + if source and isinstance(source, exp.Expression): + source = source.name or source.alias + + if source is None: + context, table_iter = self.static() + elif source in context: + if not step.projections and not step.condition: + return self.context({step.name: context.tables[source]}) + table_iter = context.table_iter(source) + else: + context, table_iter = self.scan_table(step) + + return self.context( + {step.name: self._project_and_filter(context, step, table_iter)} + 
) + + def _project_and_filter(self, context, step, table_iter): + sink = self.table(step.projections if step.projections else context.columns) + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) + + for reader in table_iter: + if len(sink) >= step.limit: + break + + if condition and not context.eval(condition): + continue + + if projections: + sink.append(context.eval_tuple(projections)) + else: + sink.append(reader.row) + + return sink + + def static(self): + return self.context({}), [RowReader(())] + + def scan_table(self, step): + table = self.tables.find(step.source) + context = self.context({step.source.alias_or_name: table}) + return context, iter(table) + + def join(self, step, context): + source = step.source_name + + source_table = context.tables[source] + source_context = self.context({source: source_table}) + column_ranges = {source: range(0, len(source_table.columns))} + + for name, join in step.joins.items(): + table = context.tables[name] + start = max(r.stop for r in column_ranges.values()) + column_ranges[name] = range(start, len(table.columns) + start) + join_context = self.context({name: table}) + + if join.get("source_key"): + table = self.hash_join(join, source_context, join_context) + else: + table = self.nested_loop_join(join, source_context, join_context) + + source_context = self.context( + { + name: Table(table.columns, table.rows, column_range) + for name, column_range in column_ranges.items() + } + ) + condition = self.generate(join["condition"]) + if condition: + source_context.filter(condition) + + if not step.condition and not step.projections: + return source_context + + sink = self._project_and_filter( + source_context, + step, + (reader for reader, _ in iter(source_context)), + ) + + if step.projections: + return self.context({step.name: sink}) + else: + return self.context( + { + name: Table(table.columns, sink.rows, table.column_range) + for name, table in source_context.tables.items() + } + ) + + def nested_loop_join(self, _join, source_context, join_context): + table = Table(source_context.columns + join_context.columns) + + for reader_a, _ in source_context: + for reader_b, _ in join_context: + table.append(reader_a.row + reader_b.row) + + return table + + def hash_join(self, join, source_context, join_context): + source_key = self.generate_tuple(join["source_key"]) + join_key = self.generate_tuple(join["join_key"]) + left = join.get("side") == "LEFT" + right = join.get("side") == "RIGHT" + + results = collections.defaultdict(lambda: ([], [])) + + for reader, ctx in source_context: + results[ctx.eval_tuple(source_key)][0].append(reader.row) + for reader, ctx in join_context: + results[ctx.eval_tuple(join_key)][1].append(reader.row) + + table = Table(source_context.columns + join_context.columns) + nulls = [ + (None,) * len(join_context.columns if left else source_context.columns) + ] + + for a_group, b_group in results.values(): + if left: + b_group = b_group or nulls + elif right: + a_group = a_group or nulls + + for a_row, b_row in itertools.product(a_group, b_group): + table.append(a_row + b_row) + + return table + + def aggregate(self, step, context): + group_by = self.generate_tuple(step.group.values()) + aggregations = self.generate_tuple(step.aggregations) + operands = self.generate_tuple(step.operands) + + if operands: + operand_table = Table(self.table(step.operands).columns) + + for reader, ctx in context: + operand_table.append(ctx.eval_tuple(operands)) + + for i, (a, b) in 
enumerate(zip(context.table.rows, operand_table.rows)): + context.table.rows[i] = a + b + + width = len(context.columns) + context.add_columns(*operand_table.columns) + + operand_table = Table( + context.columns, + context.table.rows, + range(width, width + len(operand_table.columns)), + ) + + context = self.context( + { + None: operand_table, + **context.tables, + } + ) + + context.sort(group_by) + + group = None + start = 0 + end = 1 + length = len(context.table) + table = self.table(list(step.group) + step.aggregations) + + def add_row(): + table.append(group + context.eval_tuple(aggregations)) + + if length: + for i in range(length): + context.set_index(i) + key = context.eval_tuple(group_by) + group = key if group is None else group + end += 1 + if key != group: + context.set_range(start, end - 2) + add_row() + group = key + start = end - 2 + if len(table.rows) >= step.limit: + break + if i == length - 1: + context.set_range(start, end - 1) + add_row() + elif step.limit > 0 and not group_by: + context.set_range(0, 0) + table.append(context.eval_tuple(aggregations)) + + context = self.context( + {step.name: table, **{name: table for name in context.tables}} + ) + + if step.projections or step.condition: + return self.scan(step, context) + return context + + def sort(self, step, context): + projections = self.generate_tuple(step.projections) + projection_columns = [p.alias_or_name for p in step.projections] + all_columns = list(context.columns) + projection_columns + sink = self.table(all_columns) + for reader, ctx in context: + sink.append(reader.row + ctx.eval_tuple(projections)) + + sort_ctx = self.context( + { + None: sink, + **{table: sink for table in context.tables}, + } + ) + sort_ctx.sort(self.generate_tuple(step.key)) + + if not math.isinf(step.limit): + sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit] + + output = Table( + projection_columns, + rows=[ + r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows + ], + ) + return self.context({step.name: output}) + + def set_operation(self, step, context): + left = context.tables[step.left] + right = context.tables[step.right] + + sink = self.table(left.columns) + + if issubclass(step.op, exp.Intersect): + sink.rows = list(set(left.rows).intersection(set(right.rows))) + elif issubclass(step.op, exp.Except): + sink.rows = list(set(left.rows).difference(set(right.rows))) + elif issubclass(step.op, exp.Union) and step.distinct: + sink.rows = list(set(left.rows).union(set(right.rows))) + else: + sink.rows = left.rows + right.rows + + if not math.isinf(step.limit): + sink.rows = sink.rows[0 : step.limit] + + return self.context({step.name: sink}) + + +def _ordered_py(self, expression): + this = self.sql(expression, "this") + desc = "True" if expression.args.get("desc") else "False" + nulls_first = "True" if expression.args.get("nulls_first") else "False" + return f"ORDERED({this}, {desc}, {nulls_first})" + + +def _rename(self, e): + try: + values = list(e.args.values()) + + if len(values) == 1: + values = values[0] + if not isinstance(values, list): + return self.func(e.key, values) + return self.func(e.key, *values) + + if isinstance(e, exp.Func) and e.is_var_len_args: + args = itertools.chain.from_iterable( + x if isinstance(x, list) else [x] for x in values + ) + return self.func(e.key, *args) + + return self.func(e.key, *values) + except Exception as ex: + raise Exception(f"Could not rename {repr(e)}") from ex + + +def _case_sql(self, expression): + this = self.sql(expression, "this") + chain = 
self.sql(expression, "default") or "None" + + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" + + return chain + + +def _lambda_sql(self, e: exp.Lambda) -> str: + names = {e.name.lower() for e in e.expressions} + + e = e.transform( + lambda n: ( + exp.var(n.name) + if isinstance(n, exp.Identifier) and n.name.lower() in names + else n + ) + ).assert_is(exp.Lambda) + + return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" + + +def _div_sql(self: generator.Generator, e: exp.Div) -> str: + denominator = self.sql(e, "expression") + + if e.args.get("safe"): + denominator += " or None" + + sql = f"DIV({self.sql(e, 'this')}, {denominator})" + + if e.args.get("typed"): + sql = f"int({sql})" + + return sql + + +class Python(Dialect): + class Tokenizer(tokens.Tokenizer): + STRING_ESCAPES = ["\\"] + + class Generator(generator.Generator): + TRANSFORMS = { + **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, + **{klass: _rename for klass in exp.ALL_FUNCTIONS}, + exp.Case: _case_sql, + exp.Alias: lambda self, e: self.sql(e.this), + exp.Array: inline_array_sql, + exp.And: lambda self, e: self.binary(e, "and"), + exp.Between: _rename, + exp.Boolean: lambda self, e: "True" if e.this else "False", + exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", + exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Concat: lambda self, e: self.func( + "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions + ), + exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", + exp.Div: _div_sql, + exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", + exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", + exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", + exp.Is: lambda self, e: ( + self.binary(e, "==") + if isinstance(e.this, exp.Literal) + else self.binary(e, "is") + ), + exp.JSONExtract: lambda self, e: self.func( + e.key, e.this, e.expression, *e.expressions + ), + exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]", + exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'", + exp.JSONPathSubscript: lambda self, e: f"'{e.this}'", + exp.Lambda: _lambda_sql, + exp.Not: lambda self, e: f"not {self.sql(e.this)}", + exp.Null: lambda *_: "None", + exp.Or: lambda self, e: self.binary(e, "or"), + exp.Ordered: _ordered_py, + exp.Star: lambda *_: "1", + } diff --git a/third_party/bigframes_vendored/sqlglot/executor/table.py b/third_party/bigframes_vendored/sqlglot/executor/table.py new file mode 100644 index 0000000000..c42385ca3e --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/executor/table.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot.dialects.dialect import DialectType +from bigframes_vendored.sqlglot.helper import dict_depth +from bigframes_vendored.sqlglot.schema import AbstractMappingSchema, normalize_name + + +class Table: + def __init__( + self, + columns: t.Iterable, + rows: t.Optional[t.List] = None, + column_range: t.Optional[range] = None, + ) -> None: + self.columns = tuple(columns) + self.column_range = column_range + self.reader = RowReader(self.columns, self.column_range) + self.rows = rows or [] 
+ if rows: + assert len(rows[0]) == len(self.columns) + self.range_reader = RangeReader(self) + + def add_columns(self, *columns: str) -> None: + self.columns += columns + if self.column_range: + self.column_range = range( + self.column_range.start, self.column_range.stop + len(columns) + ) + self.reader = RowReader(self.columns, self.column_range) + + def append(self, row: t.List) -> None: + assert len(row) == len(self.columns) + self.rows.append(row) + + def pop(self) -> None: + self.rows.pop() + + def to_pylist(self) -> t.List: + return [dict(zip(self.columns, row)) for row in self.rows] + + @property + def width(self) -> int: + return len(self.columns) + + def __len__(self) -> int: + return len(self.rows) + + def __iter__(self) -> TableIter: + return TableIter(self) + + def __getitem__(self, index: int) -> RowReader: + self.reader.row = self.rows[index] + return self.reader + + def __repr__(self) -> str: + columns = tuple( + column + for i, column in enumerate(self.columns) + if not self.column_range or i in self.column_range + ) + widths = {column: len(column) for column in columns} + lines = [" ".join(column for column in columns)] + + for i, row in enumerate(self): + if i > 10: + break + + lines.append( + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] + for column in columns + ) + ) + return "\n".join(lines) + + +class TableIter: + def __init__(self, table: Table) -> None: + self.table = table + self.index = -1 + + def __iter__(self) -> TableIter: + return self + + def __next__(self) -> RowReader: + self.index += 1 + if self.index < len(self.table): + return self.table[self.index] + raise StopIteration + + +class RangeReader: + def __init__(self, table: Table) -> None: + self.table = table + self.range = range(0) + + def __len__(self) -> int: + return len(self.range) + + def __getitem__(self, column: str): + return (self.table[i][column] for i in self.range) + + +class RowReader: + def __init__(self, columns, column_range=None): + self.columns = { + column: i + for i, column in enumerate(columns) + if not column_range or i in column_range + } + self.row = None + + def __getitem__(self, column): + return self.row[self.columns[column]] + + +class Tables(AbstractMappingSchema): + pass + + +def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables: + return Tables(_ensure_tables(d, dialect=dialect)) + + +def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict: + if not d: + return {} + + depth = dict_depth(d) + if depth > 1: + return { + normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables( + v, dialect=dialect + ) + for k, v in d.items() + } + + result = {} + for table_name, table in d.items(): + table_name = normalize_name(table_name, dialect=dialect).name + + if isinstance(table, Table): + result[table_name] = table + else: + table = [ + { + normalize_name(column_name, dialect=dialect).name: value + for column_name, value in row.items() + } + for row in table + ] + column_names = ( + tuple(column_name for column_name in table[0]) if table else () + ) + rows = [tuple(row[name] for name in column_names) for row in table] + result[table_name] = Table(columns=column_names, rows=rows) + + return result diff --git a/third_party/bigframes_vendored/sqlglot/expressions.py b/third_party/bigframes_vendored/sqlglot/expressions.py new file mode 100644 index 0000000000..c7a12889f2 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/expressions.py @@ -0,0 +1,10479 @@ +""" +## Expressions + +Every AST 
node in SQLGlot is represented by a subclass of `Expression`. + +This module contains the implementation of all supported `Expression` types. Additionally, +it exposes a number of helper functions, which are mainly used to programmatically build +SQL expressions, such as `sqlglot.expressions.select`. + +---- +""" + +from __future__ import annotations + +from collections import deque +from copy import deepcopy +import datetime +from decimal import Decimal +from enum import auto +from functools import reduce +import math +import numbers +import re +import sys +import textwrap +import typing as t + +from bigframes_vendored.sqlglot.errors import ErrorLevel, ParseError +from bigframes_vendored.sqlglot.helper import ( + AutoName, + camel_to_snake_case, + ensure_collection, + ensure_list, + seq_get, + split_num_words, + subclasses, + to_bool, +) +from bigframes_vendored.sqlglot.tokens import Token, TokenError + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E, Lit + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from typing_extensions import Self + + Q = t.TypeVar("Q", bound="Query") + S = t.TypeVar("S", bound="SetOperation") + + +class _Expression(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # When an Expression class is created, its key is automatically set + # to be the lowercase version of the class' name. + klass.key = clsname.lower() + klass.required_args = {k for k, v in klass.arg_types.items() if v} + + # This is so that docstrings are not inherited in pdoc + klass.__doc__ = klass.__doc__ or "" + + return klass + + +SQLGLOT_META = "sqlglot.meta" +SQLGLOT_ANONYMOUS = "sqlglot.anonymous" +TABLE_PARTS = ("this", "db", "catalog") +COLUMN_PARTS = ("this", "table", "db", "catalog") +POSITION_META_KEYS = ("line", "col", "start", "end") +UNITTEST = "unittest" in sys.modules or "pytest" in sys.modules + + +class Expression(metaclass=_Expression): + """ + The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary + context, such as its child expressions, their names (arg keys), and whether a given child expression + is optional or not. + + Attributes: + key: a unique key for each class in the Expression hierarchy. This is useful for hashing + and representing expressions as strings. + arg_types: determines the arguments (child nodes) supported by an expression. It maps + arg keys to booleans that indicate whether the corresponding args are optional. + parent: a reference to the parent expression (or None, in case of root expressions). + arg_key: the arg key an expression is associated with, i.e. the name its parent expression + uses to refer to it. + index: the index of an expression if it is inside of a list argument in its parent. + comments: a list of comments that are associated with a given expression. This is used in + order to preserve comments when transpiling SQL code. + type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the + optimizer, in order to enable some transformations that require type information. + meta: a dictionary that can be used to store useful metadata for a given expression. + + Example: + >>> class Foo(Expression): + ... arg_types = {"this": True, "expression": False} + + The above definition informs us that Foo is an Expression that requires an argument called + "this" and may also optionally receive an argument called "expression". 
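+
+ Such a node is then constructed by passing these args as keyword arguments, e.g.
+ `Foo(this=Literal.number(1))`.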
+ + Args: + args: a mapping used for retrieving the arguments of an expression, given their arg keys. + """ + + key = "expression" + arg_types = {"this": True} + required_args = {"this"} + __slots__ = ( + "args", + "parent", + "arg_key", + "index", + "comments", + "_type", + "_meta", + "_hash", + ) + + def __init__(self, **args: t.Any): + self.args: t.Dict[str, t.Any] = args + self.parent: t.Optional[Expression] = None + self.arg_key: t.Optional[str] = None + self.index: t.Optional[int] = None + self.comments: t.Optional[t.List[str]] = None + self._type: t.Optional[DataType] = None + self._meta: t.Optional[t.Dict[str, t.Any]] = None + self._hash: t.Optional[int] = None + + for arg_key, value in self.args.items(): + self._set_parent(arg_key, value) + + def __eq__(self, other) -> bool: + return self is other or ( + type(self) is type(other) and hash(self) == hash(other) + ) + + def __hash__(self) -> int: + if self._hash is None: + nodes = [] + queue = deque([self]) + + while queue: + node = queue.popleft() + nodes.append(node) + + for v in node.iter_expressions(): + if v._hash is None: + queue.append(v) + + for node in reversed(nodes): + hash_ = hash(node.key) + t = type(node) + + if t is Literal or t is Identifier: + for k, v in sorted(node.args.items()): + if v: + hash_ = hash((hash_, k, v)) + else: + for k, v in sorted(node.args.items()): + t = type(v) + + if t is list: + for x in v: + if x is not None and x is not False: + hash_ = hash( + (hash_, k, x.lower() if type(x) is str else x) + ) + else: + hash_ = hash((hash_, k)) + elif v is not None and v is not False: + hash_ = hash((hash_, k, v.lower() if t is str else v)) + + node._hash = hash_ + assert self._hash + return self._hash + + def __reduce__(self) -> t.Tuple[t.Callable, t.Tuple[t.List[t.Dict[str, t.Any]]]]: + from bigframes_vendored.sqlglot.serde import dump, load + + return (load, (dump(self),)) + + @property + def this(self) -> t.Any: + """ + Retrieves the argument with key "this". + """ + return self.args.get("this") + + @property + def expression(self) -> t.Any: + """ + Retrieves the argument with key "expression". + """ + return self.args.get("expression") + + @property + def expressions(self) -> t.List[t.Any]: + """ + Retrieves the argument with key "expressions". + """ + return self.args.get("expressions") or [] + + def text(self, key) -> str: + """ + Returns a textual representation of the argument corresponding to "key". This can only be used + for args that are strings or leaf Expression instances, such as identifiers and literals. + """ + field = self.args.get(key) + if isinstance(field, str): + return field + if isinstance(field, (Identifier, Literal, Var)): + return field.this + if isinstance(field, (Star, Null)): + return field.name + return "" + + @property + def is_string(self) -> bool: + """ + Checks whether a Literal expression is a string. + """ + return isinstance(self, Literal) and self.args["is_string"] + + @property + def is_number(self) -> bool: + """ + Checks whether a Literal expression is a number. + """ + return (isinstance(self, Literal) and not self.args["is_string"]) or ( + isinstance(self, Neg) and self.this.is_number + ) + + def to_py(self) -> t.Any: + """ + Returns a Python object equivalent of the SQL node. + """ + raise ValueError(f"{self} cannot be converted to a Python object.") + + @property + def is_int(self) -> bool: + """ + Checks whether an expression is an integer. 
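+
+ For example, `Literal.number(5).is_int` is True, while `Literal.number(5.0).is_int`
+ is False.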
+ """ + return self.is_number and isinstance(self.to_py(), int) + + @property + def is_star(self) -> bool: + """Checks whether an expression is a star.""" + return isinstance(self, Star) or ( + isinstance(self, Column) and isinstance(self.this, Star) + ) + + @property + def alias(self) -> str: + """ + Returns the alias of the expression, or an empty string if it's not aliased. + """ + if isinstance(self.args.get("alias"), TableAlias): + return self.args["alias"].name + return self.text("alias") + + @property + def alias_column_names(self) -> t.List[str]: + table_alias = self.args.get("alias") + if not table_alias: + return [] + return [c.name for c in table_alias.args.get("columns") or []] + + @property + def name(self) -> str: + return self.text("this") + + @property + def alias_or_name(self) -> str: + return self.alias or self.name + + @property + def output_name(self) -> str: + """ + Name of the output column if this expression is a selection. + + If the Expression has no output name, an empty string is returned. + + Example: + >>> from sqlglot import parse_one + >>> parse_one("SELECT a").expressions[0].output_name + 'a' + >>> parse_one("SELECT b AS c").expressions[0].output_name + 'c' + >>> parse_one("SELECT 1 + 2").expressions[0].output_name + '' + """ + return "" + + @property + def type(self) -> t.Optional[DataType]: + return self._type + + @type.setter + def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None: + if dtype and not isinstance(dtype, DataType): + dtype = DataType.build(dtype) + self._type = dtype # type: ignore + + def is_type(self, *dtypes) -> bool: + return self.type is not None and self.type.is_type(*dtypes) + + def is_leaf(self) -> bool: + return not any( + isinstance(v, (Expression, list)) and v for v in self.args.values() + ) + + @property + def meta(self) -> t.Dict[str, t.Any]: + if self._meta is None: + self._meta = {} + return self._meta + + def __deepcopy__(self, memo): + root = self.__class__() + stack = [(self, root)] + + while stack: + node, copy = stack.pop() + + if node.comments is not None: + copy.comments = deepcopy(node.comments) + if node._type is not None: + copy._type = deepcopy(node._type) + if node._meta is not None: + copy._meta = deepcopy(node._meta) + if node._hash is not None: + copy._hash = node._hash + + for k, vs in node.args.items(): + if hasattr(vs, "parent"): + stack.append((vs, vs.__class__())) + copy.set(k, stack[-1][-1]) + elif type(vs) is list: + copy.args[k] = [] + + for v in vs: + if hasattr(v, "parent"): + stack.append((v, v.__class__())) + copy.append(k, stack[-1][-1]) + else: + copy.append(k, v) + else: + copy.args[k] = vs + + return root + + def copy(self) -> Self: + """ + Returns a deep copy of the expression. + """ + return deepcopy(self) + + def add_comments( + self, comments: t.Optional[t.List[str]] = None, prepend: bool = False + ) -> None: + if self.comments is None: + self.comments = [] + + if comments: + for comment in comments: + _, *meta = comment.split(SQLGLOT_META) + if meta: + for kv in "".join(meta).split(","): + k, *v = kv.split("=") + value = v[0].strip() if v else True + self.meta[k.strip()] = to_bool(value) + + if not prepend: + self.comments.append(comment) + + if prepend: + self.comments = comments + self.comments + + def pop_comments(self) -> t.List[str]: + comments = self.comments or [] + self.comments = None + return comments + + def append(self, arg_key: str, value: t.Any) -> None: + """ + Appends value to arg_key if it's a list or sets it as a new list. 
+
+        Args:
+            arg_key (str): name of the list expression arg
+            value (Any): value to append to the list
+        """
+        if type(self.args.get(arg_key)) is not list:
+            self.args[arg_key] = []
+        self._set_parent(arg_key, value)
+        values = self.args[arg_key]
+        if hasattr(value, "parent"):
+            value.index = len(values)
+        values.append(value)
+
+    def set(
+        self,
+        arg_key: str,
+        value: t.Any,
+        index: t.Optional[int] = None,
+        overwrite: bool = True,
+    ) -> None:
+        """
+        Sets arg_key to value.
+
+        Args:
+            arg_key: name of the expression arg.
+            value: value to set the arg to.
+            index: if the arg is a list, this specifies what position to add the value in it.
+            overwrite: assuming an index is given, this determines whether to overwrite the
+                list entry instead of only inserting a new value (i.e., like list.insert).
+        """
+        expression: t.Optional[Expression] = self
+
+        while expression and expression._hash is not None:
+            expression._hash = None
+            expression = expression.parent
+
+        if index is not None:
+            expressions = self.args.get(arg_key) or []
+
+            if seq_get(expressions, index) is None:
+                return
+            if value is None:
+                expressions.pop(index)
+                for v in expressions[index:]:
+                    v.index = v.index - 1
+                return
+
+            if isinstance(value, list):
+                expressions.pop(index)
+                expressions[index:index] = value
+            elif overwrite:
+                expressions[index] = value
+            else:
+                expressions.insert(index, value)
+
+            value = expressions
+        elif value is None:
+            self.args.pop(arg_key, None)
+            return
+
+        self.args[arg_key] = value
+        self._set_parent(arg_key, value, index)
+
+    def _set_parent(
+        self, arg_key: str, value: t.Any, index: t.Optional[int] = None
+    ) -> None:
+        if hasattr(value, "parent"):
+            value.parent = self
+            value.arg_key = arg_key
+            value.index = index
+        elif type(value) is list:
+            for index, v in enumerate(value):
+                if hasattr(v, "parent"):
+                    v.parent = self
+                    v.arg_key = arg_key
+                    v.index = index
+
+    @property
+    def depth(self) -> int:
+        """
+        Returns the depth of this tree.
+        """
+        if self.parent:
+            return self.parent.depth + 1
+        return 0
+
+    def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]:
+        """Yields all argument expressions, exploding list args."""
+        for vs in reversed(self.args.values()) if reverse else self.args.values():  # type: ignore
+            if type(vs) is list:
+                for v in reversed(vs) if reverse else vs:  # type: ignore
+                    if hasattr(v, "parent"):
+                        yield v
+            elif hasattr(vs, "parent"):
+                yield vs
+
+    def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]:
+        """
+        Returns the first node in this tree which matches at least one of
+        the specified types.
+
+        Args:
+            expression_types: the expression type(s) to match.
+            bfs: whether to search the AST using the BFS algorithm (DFS is used if false).
+
+        Returns:
+            The node which matches the criteria or None if no such node was found.
+        """
+        return next(self.find_all(*expression_types, bfs=bfs), None)
+
+    def find_all(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Iterator[E]:
+        """
+        Returns a generator object which visits all nodes in this tree and only
+        yields those that match at least one of the specified expression types.
+
+        Args:
+            expression_types: the expression type(s) to match.
+            bfs: whether to search the AST using the BFS algorithm (DFS is used if false).
+
+        Returns:
+            The generator object.
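+
+        Example (illustrative; uses the same ``parse_one`` doctest convention as
+        the rest of this module):
+            >>> from sqlglot import parse_one
+            >>> [col.name for col in parse_one("SELECT a, b FROM t").find_all(Column)]
+            ['a', 'b']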
+ """ + for expression in self.walk(bfs=bfs): + if isinstance(expression, expression_types): + yield expression + + def find_ancestor(self, *expression_types: t.Type[E]) -> t.Optional[E]: + """ + Returns a nearest parent matching expression_types. + + Args: + expression_types: the expression type(s) to match. + + Returns: + The parent node. + """ + ancestor = self.parent + while ancestor and not isinstance(ancestor, expression_types): + ancestor = ancestor.parent + return ancestor # type: ignore + + @property + def parent_select(self) -> t.Optional[Select]: + """ + Returns the parent select statement. + """ + return self.find_ancestor(Select) + + @property + def same_parent(self) -> bool: + """Returns if the parent is the same class as itself.""" + return type(self.parent) is self.__class__ + + def root(self) -> Expression: + """ + Returns the root expression of this tree. + """ + expression = self + while expression.parent: + expression = expression.parent + return expression + + def walk( + self, bfs: bool = True, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: + """ + Returns a generator object which visits all nodes in this tree. + + Args: + bfs: if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + prune: callable that returns True if the generator should stop traversing + this branch of the tree. + + Returns: + the generator object. + """ + if bfs: + yield from self.bfs(prune=prune) + else: + yield from self.dfs(prune=prune) + + def dfs( + self, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: + """ + Returns a generator object which visits all nodes in this tree in + the DFS (Depth-first) order. + + Returns: + The generator object. + """ + stack = [self] + + while stack: + node = stack.pop() + + yield node + + if prune and prune(node): + continue + + for v in node.iter_expressions(reverse=True): + stack.append(v) + + def bfs( + self, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: + """ + Returns a generator object which visits all nodes in this tree in + the BFS (Breadth-first) order. + + Returns: + The generator object. + """ + queue = deque([self]) + + while queue: + node = queue.popleft() + + yield node + + if prune and prune(node): + continue + + for v in node.iter_expressions(): + queue.append(v) + + def unnest(self): + """ + Returns the first non parenthesis child or self. + """ + expression = self + while type(expression) is Paren: + expression = expression.this + return expression + + def unalias(self): + """ + Returns the inner expression if this is an Alias. + """ + if isinstance(self, Alias): + return self.this + return self + + def unnest_operands(self): + """ + Returns unnested operands as a tuple. + """ + return tuple(arg.unnest() for arg in self.iter_expressions()) + + def flatten(self, unnest=True): + """ + Returns a generator which yields child nodes whose parents are the same class. 
+ + A AND B AND C -> [A, B, C] + """ + for node in self.dfs( + prune=lambda n: n.parent and type(n) is not self.__class__ + ): + if type(node) is not self.__class__: + yield node.unnest() if unnest and not isinstance( + node, Subquery + ) else node + + def __str__(self) -> str: + return self.sql() + + def __repr__(self) -> str: + return _to_s(self) + + def to_s(self) -> str: + """ + Same as __repr__, but includes additional information which can be useful + for debugging, like empty or missing args and the AST nodes' object IDs. + """ + return _to_s(self, verbose=True) + + def sql(self, dialect: DialectType = None, **opts) -> str: + """ + Returns SQL string representation of this tree. + + Args: + dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql"). + opts: other `sqlglot.generator.Generator` options. + + Returns: + The SQL string. + """ + from bigframes_vendored.sqlglot.dialects import Dialect + + return Dialect.get_or_raise(dialect).generate(self, **opts) + + def transform( + self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs + ) -> Expression: + """ + Visits all tree nodes (excluding already transformed ones) + and applies the given transformation function to each node. + + Args: + fun: a function which takes a node as an argument and returns a + new transformed node or the same node without modifications. If the function + returns None, then the corresponding node will be removed from the syntax tree. + copy: if set to True a new tree instance is constructed, otherwise the tree is + modified in place. + + Returns: + The transformed tree. + """ + root = None + new_node = None + + for node in (self.copy() if copy else self).dfs( + prune=lambda n: n is not new_node + ): + parent, arg_key, index = node.parent, node.arg_key, node.index + new_node = fun(node, *args, **kwargs) + + if not root: + root = new_node + elif parent and arg_key and new_node is not node: + parent.set(arg_key, new_node, index) + + assert root + return root.assert_is(Expression) + + @t.overload + def replace(self, expression: E) -> E: + ... + + @t.overload + def replace(self, expression: None) -> None: + ... + + def replace(self, expression): + """ + Swap out this expression with a new expression. + + For example:: + + >>> tree = Select().select("x").from_("tbl") + >>> tree.find(Column).replace(column("y")) + Column( + this=Identifier(this=y, quoted=False)) + >>> tree.sql() + 'SELECT y FROM tbl' + + Args: + expression: new node + + Returns: + The new expression or expressions. + """ + parent = self.parent + + if not parent or parent is expression: + return expression + + key = self.arg_key + value = parent.args.get(key) + + if type(expression) is list and isinstance(value, Expression): + # We are trying to replace an Expression with a list, so it's assumed that + # the intention was to really replace the parent of this expression. + value.parent.replace(expression) + else: + parent.set(key, expression, self.index) + + if expression is not self: + self.parent = None + self.arg_key = None + self.index = None + + return expression + + def pop(self: E) -> E: + """ + Remove this expression from its AST. + + Returns: + The popped expression. + """ + self.replace(None) + return self + + def assert_is(self, type_: t.Type[E]) -> E: + """ + Assert that this `Expression` is an instance of `type_`. + + If it is NOT an instance of `type_`, this raises an assertion error. + Otherwise, this returns this expression. 
+ + Examples: + This is useful for type security in chained expressions: + + >>> import sqlglot + >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql() + 'SELECT x, z FROM y' + """ + if not isinstance(self, type_): + raise AssertionError(f"{self} is not {type_}.") + return self + + def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: + """ + Checks if this expression is valid (e.g. all mandatory args are set). + + Args: + args: a sequence of values that were used to instantiate a Func expression. This is used + to check that the provided arguments don't exceed the function argument limit. + + Returns: + A list of error messages for all possible errors that were found. + """ + errors: t.List[str] = [] + + if UNITTEST: + for k in self.args: + if k not in self.arg_types: + raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}") + + for k in self.required_args: + v = self.args.get(k) + if v is None or (type(v) is list and not v): + errors.append(f"Required keyword: '{k}' missing for {self.__class__}") + + if ( + args + and isinstance(self, Func) + and len(args) > len(self.arg_types) + and not self.is_var_len_args + ): + errors.append( + f"The number of provided arguments ({len(args)}) is greater than " + f"the maximum number of supported arguments ({len(self.arg_types)})" + ) + + return errors + + def dump(self): + """ + Dump this Expression to a JSON-serializable dict. + """ + from bigframes_vendored.sqlglot.serde import dump + + return dump(self) + + @classmethod + def load(cls, obj): + """ + Load a dict (as returned by `Expression.dump`) into an Expression instance. + """ + from bigframes_vendored.sqlglot.serde import load + + return load(obj) + + def and_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, + ) -> Condition: + """ + AND this condition with one or multiple expressions. + + Example: + >>> condition("x=1").and_("y=1").sql() + 'x = 1 AND y = 1' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy the involved expressions (only applies to Expressions). + wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid + precedence issues, but can be turned off when the produced AST is too deep and + causes recursion-related issues. + opts: other options to use to parse the input expressions. + + Returns: + The new And condition. + """ + return and_(self, *expressions, dialect=dialect, copy=copy, wrap=wrap, **opts) + + def or_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, + ) -> Condition: + """ + OR this condition with one or multiple expressions. + + Example: + >>> condition("x=1").or_("y=1").sql() + 'x = 1 OR y = 1' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy the involved expressions (only applies to Expressions). + wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid + precedence issues, but can be turned off when the produced AST is too deep and + causes recursion-related issues. + opts: other options to use to parse the input expressions. + + Returns: + The new Or condition. 
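+
+        Another example (illustrative; several operands are left-folded into
+        nested ``Or`` nodes, which render flat):
+            >>> condition("x=1").or_("y=1", "z=1").sql()
+            'x = 1 OR y = 1 OR z = 1'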
+ """ + return or_(self, *expressions, dialect=dialect, copy=copy, wrap=wrap, **opts) + + def not_(self, copy: bool = True): + """ + Wrap this condition with NOT. + + Example: + >>> condition("x=1").not_().sql() + 'NOT x = 1' + + Args: + copy: whether to copy this object. + + Returns: + The new Not instance. + """ + return not_(self, copy=copy) + + def update_positions( + self: E, + other: t.Optional[Token | Expression] = None, + line: t.Optional[int] = None, + col: t.Optional[int] = None, + start: t.Optional[int] = None, + end: t.Optional[int] = None, + ) -> E: + """ + Update this expression with positions from a token or other expression. + + Args: + other: a token or expression to update this expression with. + line: the line number to use if other is None + col: column number + start: start char index + end: end char index + + Returns: + The updated expression. + """ + if other is None: + self.meta["line"] = line + self.meta["col"] = col + self.meta["start"] = start + self.meta["end"] = end + elif hasattr(other, "meta"): + for k in POSITION_META_KEYS: + self.meta[k] = other.meta[k] + else: + self.meta["line"] = other.line + self.meta["col"] = other.col + self.meta["start"] = other.start + self.meta["end"] = other.end + return self + + def as_( + self, + alias: str | Identifier, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Alias: + return alias_(self, alias, quoted=quoted, dialect=dialect, copy=copy, **opts) + + def _binop(self, klass: t.Type[E], other: t.Any, reverse: bool = False) -> E: + this = self.copy() + other = convert(other, copy=True) + if not isinstance(this, klass) and not isinstance(other, klass): + this = _wrap(this, Binary) + other = _wrap(other, Binary) + if reverse: + return klass(this=other, expression=this) + return klass(this=this, expression=other) + + def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]) -> Bracket: + return Bracket( + this=self.copy(), + expressions=[convert(e, copy=True) for e in ensure_list(other)], + ) + + def __iter__(self) -> t.Iterator: + if "expressions" in self.arg_types: + return iter(self.args.get("expressions") or []) + # We define this because __getitem__ converts Expression into an iterable, which is + # problematic because one can hit infinite loops if they do "for x in some_expr: ..." 
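+        # (e.g. "for x in some_column_expr" would otherwise recurse through
+        # __getitem__ forever, whereas nodes that define an "expressions" arg,
+        # such as Tuple or Array, legitimately yield their child expressions).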
+ # See: https://peps.python.org/pep-0234/ + raise TypeError(f"'{self.__class__.__name__}' object is not iterable") + + def isin( + self, + *expressions: t.Any, + query: t.Optional[ExpOrStr] = None, + unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, + copy: bool = True, + **opts, + ) -> In: + subquery = maybe_parse(query, copy=copy, **opts) if query else None + if subquery and not isinstance(subquery, Subquery): + subquery = subquery.subquery(copy=False) + + return In( + this=maybe_copy(self, copy), + expressions=[convert(e, copy=copy) for e in expressions], + query=subquery, + unnest=( + Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) + for e in ensure_list(unnest) + ] + ) + if unnest + else None + ), + ) + + def between( + self, + low: t.Any, + high: t.Any, + copy: bool = True, + symmetric: t.Optional[bool] = None, + **opts, + ) -> Between: + between = Between( + this=maybe_copy(self, copy), + low=convert(low, copy=copy, **opts), + high=convert(high, copy=copy, **opts), + ) + if symmetric is not None: + between.set("symmetric", symmetric) + + return between + + def is_(self, other: ExpOrStr) -> Is: + return self._binop(Is, other) + + def like(self, other: ExpOrStr) -> Like: + return self._binop(Like, other) + + def ilike(self, other: ExpOrStr) -> ILike: + return self._binop(ILike, other) + + def eq(self, other: t.Any) -> EQ: + return self._binop(EQ, other) + + def neq(self, other: t.Any) -> NEQ: + return self._binop(NEQ, other) + + def rlike(self, other: ExpOrStr) -> RegexpLike: + return self._binop(RegexpLike, other) + + def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div: + div = self._binop(Div, other) + div.set("typed", typed) + div.set("safe", safe) + return div + + def asc(self, nulls_first: bool = True) -> Ordered: + return Ordered(this=self.copy(), nulls_first=nulls_first) + + def desc(self, nulls_first: bool = False) -> Ordered: + return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first) + + def __lt__(self, other: t.Any) -> LT: + return self._binop(LT, other) + + def __le__(self, other: t.Any) -> LTE: + return self._binop(LTE, other) + + def __gt__(self, other: t.Any) -> GT: + return self._binop(GT, other) + + def __ge__(self, other: t.Any) -> GTE: + return self._binop(GTE, other) + + def __add__(self, other: t.Any) -> Add: + return self._binop(Add, other) + + def __radd__(self, other: t.Any) -> Add: + return self._binop(Add, other, reverse=True) + + def __sub__(self, other: t.Any) -> Sub: + return self._binop(Sub, other) + + def __rsub__(self, other: t.Any) -> Sub: + return self._binop(Sub, other, reverse=True) + + def __mul__(self, other: t.Any) -> Mul: + return self._binop(Mul, other) + + def __rmul__(self, other: t.Any) -> Mul: + return self._binop(Mul, other, reverse=True) + + def __truediv__(self, other: t.Any) -> Div: + return self._binop(Div, other) + + def __rtruediv__(self, other: t.Any) -> Div: + return self._binop(Div, other, reverse=True) + + def __floordiv__(self, other: t.Any) -> IntDiv: + return self._binop(IntDiv, other) + + def __rfloordiv__(self, other: t.Any) -> IntDiv: + return self._binop(IntDiv, other, reverse=True) + + def __mod__(self, other: t.Any) -> Mod: + return self._binop(Mod, other) + + def __rmod__(self, other: t.Any) -> Mod: + return self._binop(Mod, other, reverse=True) + + def __pow__(self, other: t.Any) -> Pow: + return self._binop(Pow, other) + + def __rpow__(self, other: t.Any) -> Pow: + return self._binop(Pow, other, reverse=True) + + def __and__(self, other: 
t.Any) -> And:
+        return self._binop(And, other)
+
+    def __rand__(self, other: t.Any) -> And:
+        return self._binop(And, other, reverse=True)
+
+    def __or__(self, other: t.Any) -> Or:
+        return self._binop(Or, other)
+
+    def __ror__(self, other: t.Any) -> Or:
+        return self._binop(Or, other, reverse=True)
+
+    def __neg__(self) -> Neg:
+        return Neg(this=_wrap(self.copy(), Binary))
+
+    def __invert__(self) -> Not:
+        return not_(self.copy())
+
+
+IntoType = t.Union[
+    str,
+    t.Type[Expression],
+    t.Collection[t.Union[str, t.Type[Expression]]],
+]
+ExpOrStr = t.Union[str, Expression]
+
+
+class Condition(Expression):
+    """Logical conditions like x AND y, or simply x"""
+
+
+class Predicate(Condition):
+    """Relationships like x = y, x > 1, x >= y."""
+
+
+class DerivedTable(Expression):
+    @property
+    def selects(self) -> t.List[Expression]:
+        return self.this.selects if isinstance(self.this, Query) else []
+
+    @property
+    def named_selects(self) -> t.List[str]:
+        return [select.output_name for select in self.selects]
+
+
+class Query(Expression):
+    def subquery(
+        self, alias: t.Optional[ExpOrStr] = None, copy: bool = True
+    ) -> Subquery:
+        """
+        Returns a `Subquery` that wraps around this query.
+
+        Example:
+            >>> subquery = Select().select("x").from_("tbl").subquery()
+            >>> Select().select("x").from_(subquery).sql()
+            'SELECT x FROM (SELECT x FROM tbl)'
+
+        Args:
+            alias: an optional alias for the subquery.
+            copy: if `False`, modify this expression instance in-place.
+        """
+        instance = maybe_copy(self, copy)
+        if not isinstance(alias, Expression):
+            alias = TableAlias(this=to_identifier(alias)) if alias else None
+
+        return Subquery(this=instance, alias=alias)
+
+    def limit(
+        self: Q,
+        expression: ExpOrStr | int,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Adds a LIMIT clause to this query.
+
+        Example:
+            >>> select("1").union(select("1")).limit(1).sql()
+            'SELECT 1 UNION SELECT 1 LIMIT 1'
+
+        Args:
+            expression: the SQL code string to parse.
+                This can also be an integer.
+                If a `Limit` instance is passed, it will be used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            A limited Select expression.
+        """
+        return _apply_builder(
+            expression=expression,
+            instance=self,
+            arg="limit",
+            into=Limit,
+            prefix="LIMIT",
+            dialect=dialect,
+            copy=copy,
+            into_arg="expression",
+            **opts,
+        )
+
+    def offset(
+        self: Q,
+        expression: ExpOrStr | int,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Set the OFFSET expression.
+
+        Example:
+            >>> Select().from_("tbl").select("x").offset(10).sql()
+            'SELECT x FROM tbl OFFSET 10'
+
+        Args:
+            expression: the SQL code string to parse.
+                This can also be an integer.
+                If an `Offset` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in an `Offset`.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified Select expression.
+        """
+        return _apply_builder(
+            expression=expression,
+            instance=self,
+            arg="offset",
+            into=Offset,
+            prefix="OFFSET",
+            dialect=dialect,
+            copy=copy,
+            into_arg="expression",
+            **opts,
+        )
+
+    def order_by(
+        self: Q,
+        *expressions: t.Optional[ExpOrStr],
+        append: bool = True,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Set the ORDER BY expression.
+
+        Example:
+            >>> Select().from_("tbl").select("x").order_by("x DESC").sql()
+            'SELECT x FROM tbl ORDER BY x DESC'
+
+        Args:
+            *expressions: the SQL code strings to parse.
+                If an `Order` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in an `Order`.
+            append: if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Order` expressions into a single expression.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified Select expression.
+        """
+        return _apply_child_list_builder(
+            *expressions,
+            instance=self,
+            arg="order",
+            append=append,
+            copy=copy,
+            prefix="ORDER BY",
+            into=Order,
+            dialect=dialect,
+            **opts,
+        )
+
+    @property
+    def ctes(self) -> t.List[CTE]:
+        """Returns a list of all the CTEs attached to this query."""
+        with_ = self.args.get("with_")
+        return with_.expressions if with_ else []
+
+    @property
+    def selects(self) -> t.List[Expression]:
+        """Returns the query's projections."""
+        raise NotImplementedError("Query objects must implement `selects`")
+
+    @property
+    def named_selects(self) -> t.List[str]:
+        """Returns the output names of the query's projections."""
+        raise NotImplementedError("Query objects must implement `named_selects`")
+
+    def select(
+        self: Q,
+        *expressions: t.Optional[ExpOrStr],
+        append: bool = True,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Append to or set the SELECT expressions.
+
+        Example:
+            >>> Select().select("x", "y").sql()
+            'SELECT x, y'
+
+        Args:
+            *expressions: the SQL code strings to parse.
+                If an `Expression` instance is passed, it will be used as-is.
+            append: if `True`, add to any existing expressions.
+                Otherwise, this resets the expressions.
+            dialect: the dialect used to parse the input expressions.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified Query expression.
+        """
+        raise NotImplementedError("Query objects must implement `select`")
+
+    def where(
+        self: Q,
+        *expressions: t.Optional[ExpOrStr],
+        append: bool = True,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Q:
+        """
+        Append to or set the WHERE expressions.
+
+        Examples:
+            >>> Select().select("x").from_("tbl").where("x = 'a' OR x < 'b'").sql()
+            "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'"
+
+        Args:
+            *expressions: the SQL code strings to parse.
+                If an `Expression` instance is passed, it will be used as-is.
+                Multiple expressions are combined with an AND operator.
+            append: if `True`, AND the new expressions to any existing expression.
+                Otherwise, this resets the expression.
+            dialect: the dialect used to parse the input expressions.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified expression.
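+
+        Another example (illustrative; shows that multiple conditions are ANDed):
+            >>> Select().select("x").from_("tbl").where("x = 1", "y = 2").sql()
+            'SELECT x FROM tbl WHERE x = 1 AND y = 2'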
+ """ + return _apply_conjunction_builder( + *[expr.this if isinstance(expr, Where) else expr for expr in expressions], + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def with_( + self: Q, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + materialized: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + scalar: t.Optional[bool] = None, + **opts, + ) -> Q: + """ + Append to or set the common table expressions. + + Example: + >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() + 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + materialized: set the MATERIALIZED part of the expression. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + scalar: if `True`, this is a scalar common table expression. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, + alias, + as_, + recursive=recursive, + materialized=materialized, + append=append, + dialect=dialect, + copy=copy, + scalar=scalar, + **opts, + ) + + def union( + self, + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + **opts, + ) -> Union: + """ + Builds a UNION expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + expressions: the SQL code strings. + If `Expression` instances are passed, they will be used as-is. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + + Returns: + The new Union expression. + """ + return union(self, *expressions, distinct=distinct, dialect=dialect, **opts) + + def intersect( + self, + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + **opts, + ) -> Intersect: + """ + Builds an INTERSECT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + expressions: the SQL code strings. + If `Expression` instances are passed, they will be used as-is. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + + Returns: + The new Intersect expression. + """ + return intersect(self, *expressions, distinct=distinct, dialect=dialect, **opts) + + def except_( + self, + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + **opts, + ) -> Except: + """ + Builds an EXCEPT expression. 
+
+        Example:
+            >>> import sqlglot
+            >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql()
+            'SELECT * FROM foo EXCEPT SELECT * FROM bla'
+
+        Args:
+            expressions: the SQL code strings.
+                If `Expression` instances are passed, they will be used as-is.
+            distinct: set the DISTINCT flag if and only if this is true.
+            dialect: the dialect used to parse the input expression.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The new Except expression.
+        """
+        return except_(self, *expressions, distinct=distinct, dialect=dialect, **opts)
+
+
+class UDTF(DerivedTable):
+    @property
+    def selects(self) -> t.List[Expression]:
+        alias = self.args.get("alias")
+        return alias.columns if alias else []
+
+
+class Cache(Expression):
+    arg_types = {
+        "this": True,
+        "lazy": False,
+        "options": False,
+        "expression": False,
+    }
+
+
+class Uncache(Expression):
+    arg_types = {"this": True, "exists": False}
+
+
+class Refresh(Expression):
+    arg_types = {"this": True, "kind": True}
+
+
+class DDL(Expression):
+    @property
+    def ctes(self) -> t.List[CTE]:
+        """Returns a list of all the CTEs attached to this statement."""
+        with_ = self.args.get("with_")
+        return with_.expressions if with_ else []
+
+    @property
+    def selects(self) -> t.List[Expression]:
+        """If this statement contains a query (e.g. a CTAS), this returns the query's projections."""
+        return self.expression.selects if isinstance(self.expression, Query) else []
+
+    @property
+    def named_selects(self) -> t.List[str]:
+        """
+        If this statement contains a query (e.g. a CTAS), this returns the output
+        names of the query's projections.
+        """
+        return (
+            self.expression.named_selects if isinstance(self.expression, Query) else []
+        )
+
+
+# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Data-Manipulation-Language/Statement-Syntax/LOCKING-Request-Modifier/LOCKING-Request-Modifier-Syntax
+class LockingStatement(Expression):
+    arg_types = {"this": True, "expression": True}
+
+
+class DML(Expression):
+    def returning(
+        self,
+        expression: ExpOrStr,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> "Self":
+        """
+        Set the RETURNING expression. Not supported by all dialects.
+
+        Example:
+            >>> delete("tbl").returning("*", dialect="postgres").sql()
+            'DELETE FROM tbl RETURNING *'
+
+        Args:
+            expression: the SQL code string to parse.
+                If an `Expression` instance is passed, it will be used as-is.
+            dialect: the dialect used to parse the input expressions.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            The modified expression.
+ """ + return _apply_builder( + expression=expression, + instance=self, + arg="returning", + prefix="RETURNING", + dialect=dialect, + copy=copy, + into=Returning, + **opts, + ) + + +class Create(DDL): + arg_types = { + "with_": False, + "this": True, + "kind": True, + "expression": False, + "exists": False, + "properties": False, + "replace": False, + "refresh": False, + "unique": False, + "indexes": False, + "no_schema_binding": False, + "begin": False, + "end": False, + "clone": False, + "concurrently": False, + "clustered": False, + } + + @property + def kind(self) -> t.Optional[str]: + kind = self.args.get("kind") + return kind and kind.upper() + + +class SequenceProperties(Expression): + arg_types = { + "increment": False, + "minvalue": False, + "maxvalue": False, + "cache": False, + "start": False, + "owned": False, + "options": False, + } + + +class TruncateTable(Expression): + arg_types = { + "expressions": True, + "is_database": False, + "exists": False, + "only": False, + "cluster": False, + "identity": False, + "option": False, + "partition": False, + } + + +# https://docs.snowflake.com/en/sql-reference/sql/create-clone +# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement +# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy +class Clone(Expression): + arg_types = {"this": True, "shallow": False, "copy": False} + + +class Describe(Expression): + arg_types = { + "this": True, + "style": False, + "kind": False, + "expressions": False, + "partition": False, + "format": False, + } + + +# https://duckdb.org/docs/sql/statements/attach.html#attach +class Attach(Expression): + arg_types = {"this": True, "exists": False, "expressions": False} + + +# https://duckdb.org/docs/sql/statements/attach.html#detach +class Detach(Expression): + arg_types = {"this": True, "exists": False} + + +# https://duckdb.org/docs/sql/statements/load_and_install.html +class Install(Expression): + arg_types = {"this": True, "from_": False, "force": False} + + +# https://duckdb.org/docs/guides/meta/summarize.html +class Summarize(Expression): + arg_types = {"this": True, "table": False} + + +class Kill(Expression): + arg_types = {"this": True, "kind": False} + + +class Pragma(Expression): + pass + + +class Declare(Expression): + arg_types = {"expressions": True} + + +class DeclareItem(Expression): + arg_types = {"this": True, "kind": False, "default": False} + + +class Set(Expression): + arg_types = {"expressions": False, "unset": False, "tag": False} + + +class Heredoc(Expression): + arg_types = {"this": True, "tag": False} + + +class SetItem(Expression): + arg_types = { + "this": False, + "expressions": False, + "kind": False, + "collate": False, # MySQL SET NAMES statement + "global_": False, + } + + +class QueryBand(Expression): + arg_types = {"this": True, "scope": False, "update": False} + + +class Show(Expression): + arg_types = { + "this": True, + "history": False, + "terse": False, + "target": False, + "offset": False, + "starts_with": False, + "limit": False, + "from_": False, + "like": False, + "where": False, + "db": False, + "scope": False, + "scope_kind": False, + "full": False, + "mutex": False, + "query": False, + "channel": False, + "global_": False, + "log": False, + "position": False, + "types": False, + "privileges": False, + "for_table": False, + "for_group": False, + "for_user": False, + "for_role": False, + "into_outfile": False, + "json": False, + } + + +class 
UserDefinedFunction(Expression): + arg_types = {"this": True, "expressions": False, "wrapped": False} + + +class CharacterSet(Expression): + arg_types = {"this": True, "default": False} + + +class RecursiveWithSearch(Expression): + arg_types = {"kind": True, "this": True, "expression": True, "using": False} + + +class With(Expression): + arg_types = {"expressions": True, "recursive": False, "search": False} + + @property + def recursive(self) -> bool: + return bool(self.args.get("recursive")) + + +class WithinGroup(Expression): + arg_types = {"this": True, "expression": False} + + +# clickhouse supports scalar ctes +# https://clickhouse.com/docs/en/sql-reference/statements/select/with +class CTE(DerivedTable): + arg_types = { + "this": True, + "alias": True, + "scalar": False, + "materialized": False, + "key_expressions": False, + } + + +class ProjectionDef(Expression): + arg_types = {"this": True, "expression": True} + + +class TableAlias(Expression): + arg_types = {"this": False, "columns": False} + + @property + def columns(self): + return self.args.get("columns") or [] + + +class BitString(Condition): + pass + + +class HexString(Condition): + arg_types = {"this": True, "is_integer": False} + + +class ByteString(Condition): + arg_types = {"this": True, "is_bytes": False} + + +class RawString(Condition): + pass + + +class UnicodeString(Condition): + arg_types = {"this": True, "escape": False} + + +class Column(Condition): + arg_types = { + "this": True, + "table": False, + "db": False, + "catalog": False, + "join_mark": False, + } + + @property + def table(self) -> str: + return self.text("table") + + @property + def db(self) -> str: + return self.text("db") + + @property + def catalog(self) -> str: + return self.text("catalog") + + @property + def output_name(self) -> str: + return self.name + + @property + def parts(self) -> t.List[Identifier]: + """Return the parts of a column in order catalog, db, table, name.""" + return [ + t.cast(Identifier, self.args[part]) + for part in ("catalog", "db", "table", "this") + if self.args.get(part) + ] + + def to_dot(self, include_dots: bool = True) -> Dot | Identifier: + """Converts the column into a dot expression.""" + parts = self.parts + parent = self.parent + + if include_dots: + while isinstance(parent, Dot): + parts.append(parent.expression) + parent = parent.parent + + return Dot.build(deepcopy(parts)) if len(parts) > 1 else parts[0] + + +class Pseudocolumn(Column): + pass + + +class ColumnPosition(Expression): + arg_types = {"this": False, "position": True} + + +class ColumnDef(Expression): + arg_types = { + "this": True, + "kind": False, + "constraints": False, + "exists": False, + "position": False, + "default": False, + "output": False, + } + + @property + def constraints(self) -> t.List[ColumnConstraint]: + return self.args.get("constraints") or [] + + @property + def kind(self) -> t.Optional[DataType]: + return self.args.get("kind") + + +class AlterColumn(Expression): + arg_types = { + "this": True, + "dtype": False, + "collate": False, + "using": False, + "default": False, + "drop": False, + "comment": False, + "allow_null": False, + "visible": False, + "rename_to": False, + } + + +# https://dev.mysql.com/doc/refman/8.0/en/invisible-indexes.html +class AlterIndex(Expression): + arg_types = {"this": True, "visible": True} + + +# https://docs.aws.amazon.com/redshift/latest/dg/r_ALTER_TABLE.html +class AlterDistStyle(Expression): + pass + + +class AlterSortKey(Expression): + arg_types = {"this": False, "expressions": False, "compound": 
False}
+
+
+class AlterSet(Expression):
+    arg_types = {
+        "expressions": False,
+        "option": False,
+        "tablespace": False,
+        "access_method": False,
+        "file_format": False,
+        "copy_options": False,
+        "tag": False,
+        "location": False,
+        "serde": False,
+    }
+
+
+class RenameColumn(Expression):
+    arg_types = {"this": True, "to": True, "exists": False}
+
+
+class AlterRename(Expression):
+    pass
+
+
+class SwapTable(Expression):
+    pass
+
+
+class Comment(Expression):
+    arg_types = {
+        "this": True,
+        "kind": True,
+        "expression": True,
+        "exists": False,
+        "materialized": False,
+    }
+
+
+class Comprehension(Expression):
+    arg_types = {
+        "this": True,
+        "expression": True,
+        "position": False,
+        "iterator": True,
+        "condition": False,
+    }
+
+
+# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+class MergeTreeTTLAction(Expression):
+    arg_types = {
+        "this": True,
+        "delete": False,
+        "recompress": False,
+        "to_disk": False,
+        "to_volume": False,
+    }
+
+
+# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+class MergeTreeTTL(Expression):
+    arg_types = {
+        "expressions": True,
+        "where": False,
+        "group": False,
+        "aggregates": False,
+    }
+
+
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexConstraintOption(Expression):
+    arg_types = {
+        "key_block_size": False,
+        "using": False,
+        "parser": False,
+        "comment": False,
+        "visible": False,
+        "engine_attr": False,
+        "secondary_engine_attr": False,
+    }
+
+
+class ColumnConstraint(Expression):
+    arg_types = {"this": False, "kind": True}
+
+    @property
+    def kind(self) -> ColumnConstraintKind:
+        return self.args["kind"]
+
+
+class ColumnConstraintKind(Expression):
+    pass
+
+
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class ZeroFillColumnConstraint(ColumnConstraintKind):
+    arg_types = {}
+
+
+class PeriodForSystemTimeConstraint(ColumnConstraintKind):
+    arg_types = {"this": True, "expression": True}
+
+
+class CaseSpecificColumnConstraint(ColumnConstraintKind):
+    arg_types = {"not_": True}
+
+
+class CharacterSetColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": True}
+
+
+class CheckColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": True, "enforced": False}
+
+
+class ClusteredColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class CollateColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class CommentColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class CompressColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": False}
+
+
+class DateFormatColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": True}
+
+
+class DefaultColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class EncodeColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+# https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
+class ExcludeColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class EphemeralColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": False}
+
+
+class WithOperator(Expression):
+    arg_types = {"this": True, "op": True}
+
+
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+    # this: True -> ALWAYS, this: False -> BY DEFAULT
+    arg_types = {
+        "this": False,
+        "expression": False,
+        "on_null": False,
+        "start": False,
+        "increment": False,
+        "minvalue": False,
+        "maxvalue": False,
+        "cycle": False,
+        "order": False,
+    }
+
+
+class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
+    arg_types = {"start": False,
"hidden": False} + + +# https://dev.mysql.com/doc/refman/8.0/en/create-table.html +# https://github.com/ClickHouse/ClickHouse/blob/master/src/Parsers/ParserCreateQuery.h#L646 +class IndexColumnConstraint(ColumnConstraintKind): + arg_types = { + "this": False, + "expressions": False, + "kind": False, + "index_type": False, + "options": False, + "expression": False, # Clickhouse + "granularity": False, + } + + +class InlineLengthColumnConstraint(ColumnConstraintKind): + pass + + +class NonClusteredColumnConstraint(ColumnConstraintKind): + pass + + +class NotForReplicationColumnConstraint(ColumnConstraintKind): + arg_types = {} + + +# https://docs.snowflake.com/en/sql-reference/sql/create-table +class MaskingPolicyColumnConstraint(ColumnConstraintKind): + arg_types = {"this": True, "expressions": False} + + +class NotNullColumnConstraint(ColumnConstraintKind): + arg_types = {"allow_null": False} + + +# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html +class OnUpdateColumnConstraint(ColumnConstraintKind): + pass + + +class PrimaryKeyColumnConstraint(ColumnConstraintKind): + arg_types = {"desc": False, "options": False} + + +class TitleColumnConstraint(ColumnConstraintKind): + pass + + +class UniqueColumnConstraint(ColumnConstraintKind): + arg_types = { + "this": False, + "index_type": False, + "on_conflict": False, + "nulls": False, + "options": False, + } + + +class UppercaseColumnConstraint(ColumnConstraintKind): + arg_types: t.Dict[str, t.Any] = {} + + +# https://docs.risingwave.com/processing/watermarks#syntax +class WatermarkColumnConstraint(Expression): + arg_types = {"this": True, "expression": True} + + +class PathColumnConstraint(ColumnConstraintKind): + pass + + +# https://docs.snowflake.com/en/sql-reference/sql/create-table +class ProjectionPolicyColumnConstraint(ColumnConstraintKind): + pass + + +# computed column expression +# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16 +class ComputedColumnConstraint(ColumnConstraintKind): + arg_types = { + "this": True, + "persisted": False, + "not_null": False, + "data_type": False, + } + + +class Constraint(Expression): + arg_types = {"this": True, "expressions": True} + + +class Delete(DML): + arg_types = { + "with_": False, + "this": False, + "using": False, + "where": False, + "returning": False, + "order": False, + "limit": False, + "tables": False, # Multiple-Table Syntax (MySQL) + "cluster": False, # Clickhouse + } + + def delete( + self, + table: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Create a DELETE expression or replace the table on an existing DELETE expression. + + Example: + >>> delete("tbl").sql() + 'DELETE FROM tbl' + + Args: + table: the table from which to delete. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=table, + instance=self, + arg="this", + dialect=dialect, + into=Table, + copy=copy, + **opts, + ) + + def where( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Append to or set the WHERE expressions. 
+ + Example: + >>> delete("tbl").where("x = 'a' OR x < 'b'").sql() + "DELETE FROM tbl WHERE x = 'a' OR x < 'b'" + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + +class Drop(Expression): + arg_types = { + "this": False, + "kind": False, + "expressions": False, + "exists": False, + "temporary": False, + "materialized": False, + "cascade": False, + "constraints": False, + "purge": False, + "cluster": False, + "concurrently": False, + } + + @property + def kind(self) -> t.Optional[str]: + kind = self.args.get("kind") + return kind and kind.upper() + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/export-statements +class Export(Expression): + arg_types = {"this": True, "connection": False, "options": True} + + +class Filter(Expression): + arg_types = {"this": True, "expression": True} + + +class Check(Expression): + pass + + +class Changes(Expression): + arg_types = {"information": True, "at_before": False, "end": False} + + +# https://docs.snowflake.com/en/sql-reference/constructs/connect-by +class Connect(Expression): + arg_types = {"start": False, "connect": True, "nocycle": False} + + +class CopyParameter(Expression): + arg_types = {"this": True, "expression": False, "expressions": False} + + +class Copy(DML): + arg_types = { + "this": True, + "kind": True, + "files": False, + "credentials": False, + "format": False, + "params": False, + } + + +class Credentials(Expression): + arg_types = { + "credentials": False, + "encryption": False, + "storage": False, + "iam_role": False, + "region": False, + } + + +class Prior(Expression): + pass + + +class Directory(Expression): + arg_types = {"this": True, "local": False, "row_format": False} + + +# https://docs.snowflake.com/en/user-guide/data-load-dirtables-query +class DirectoryStage(Expression): + pass + + +class ForeignKey(Expression): + arg_types = { + "expressions": False, + "reference": False, + "delete": False, + "update": False, + "options": False, + } + + +class ColumnPrefix(Expression): + arg_types = {"this": True, "expression": True} + + +class PrimaryKey(Expression): + arg_types = {"this": False, "expressions": True, "options": False, "include": False} + + +# https://www.postgresql.org/docs/9.1/sql-selectinto.html +# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples +class Into(Expression): + arg_types = { + "this": False, + "temporary": False, + "unlogged": False, + "bulk_collect": False, + "expressions": False, + } + + +class From(Expression): + @property + def name(self) -> str: + return self.this.name + + @property + def alias_or_name(self) -> str: + return self.this.alias_or_name + + +class Having(Expression): + pass + + +class Hint(Expression): + arg_types = {"expressions": True} + + +class JoinHint(Expression): + arg_types = {"this": True, "expressions": True} + + +class Identifier(Expression): + arg_types = {"this": True, "quoted": 
False, "global_": False, "temporary": False} + + @property + def quoted(self) -> bool: + return bool(self.args.get("quoted")) + + @property + def output_name(self) -> str: + return self.name + + +# https://www.postgresql.org/docs/current/indexes-opclass.html +class Opclass(Expression): + arg_types = {"this": True, "expression": True} + + +class Index(Expression): + arg_types = { + "this": False, + "table": False, + "unique": False, + "primary": False, + "amp": False, # teradata + "params": False, + } + + +class IndexParameters(Expression): + arg_types = { + "using": False, + "include": False, + "columns": False, + "with_storage": False, + "partition_by": False, + "tablespace": False, + "where": False, + "on": False, + } + + +class Insert(DDL, DML): + arg_types = { + "hint": False, + "with_": False, + "is_function": False, + "this": False, + "expression": False, + "conflict": False, + "returning": False, + "overwrite": False, + "exists": False, + "alternative": False, + "where": False, + "ignore": False, + "by_name": False, + "stored": False, + "partition": False, + "settings": False, + "source": False, + "default": False, + } + + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + materialized: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Insert: + """ + Append to or set the common table expressions. + + Example: + >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql() + 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + materialized: set the MATERIALIZED part of the expression. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. 
+ """ + return _apply_cte_builder( + self, + alias, + as_, + recursive=recursive, + materialized=materialized, + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + +class ConditionalInsert(Expression): + arg_types = {"this": True, "expression": False, "else_": False} + + +class MultitableInserts(Expression): + arg_types = {"expressions": True, "kind": True, "source": True} + + +class OnConflict(Expression): + arg_types = { + "duplicate": False, + "expressions": False, + "action": False, + "conflict_keys": False, + "constraint": False, + "where": False, + } + + +class OnCondition(Expression): + arg_types = {"error": False, "empty": False, "null": False} + + +class Returning(Expression): + arg_types = {"expressions": True, "into": False} + + +# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html +class Introducer(Expression): + arg_types = {"this": True, "expression": True} + + +# national char, like n'utf8' +class National(Expression): + pass + + +class LoadData(Expression): + arg_types = { + "this": True, + "local": False, + "overwrite": False, + "inpath": True, + "partition": False, + "input_format": False, + "serde": False, + } + + +class Partition(Expression): + arg_types = {"expressions": True, "subpartition": False} + + +class PartitionRange(Expression): + arg_types = {"this": True, "expression": False, "expressions": False} + + +# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#how-to-set-partition-expression +class PartitionId(Expression): + pass + + +class Fetch(Expression): + arg_types = { + "direction": False, + "count": False, + "limit_options": False, + } + + +class Grant(Expression): + arg_types = { + "privileges": True, + "kind": False, + "securable": True, + "principals": True, + "grant_option": False, + } + + +class Revoke(Expression): + arg_types = {**Grant.arg_types, "cascade": False} + + +class Group(Expression): + arg_types = { + "expressions": False, + "grouping_sets": False, + "cube": False, + "rollup": False, + "totals": False, + "all": False, + } + + +class Cube(Expression): + arg_types = {"expressions": False} + + +class Rollup(Expression): + arg_types = {"expressions": False} + + +class GroupingSets(Expression): + arg_types = {"expressions": True} + + +class Lambda(Expression): + arg_types = {"this": True, "expressions": True, "colon": False} + + +class Limit(Expression): + arg_types = { + "this": False, + "expression": True, + "offset": False, + "limit_options": False, + "expressions": False, + } + + +class LimitOptions(Expression): + arg_types = { + "percent": False, + "rows": False, + "with_ties": False, + } + + +class Literal(Condition): + arg_types = {"this": True, "is_string": True} + + @classmethod + def number(cls, number) -> Literal: + return cls(this=str(number), is_string=False) + + @classmethod + def string(cls, string) -> Literal: + return cls(this=str(string), is_string=True) + + @property + def output_name(self) -> str: + return self.name + + def to_py(self) -> int | str | Decimal: + if self.is_number: + try: + return int(self.this) + except ValueError: + return Decimal(self.this) + return self.this + + +class Join(Expression): + arg_types = { + "this": True, + "on": False, + "side": False, + "kind": False, + "using": False, + "method": False, + "global_": False, + "hint": False, + "match_condition": False, # Snowflake + "expressions": False, + "pivots": False, + } + + @property + def method(self) -> str: + return self.text("method").upper() + + @property + def kind(self) -> str: + return 
self.text("kind").upper() + + @property + def side(self) -> str: + return self.text("side").upper() + + @property + def hint(self) -> str: + return self.text("hint").upper() + + @property + def alias_or_name(self) -> str: + return self.this.alias_or_name + + @property + def is_semi_or_anti_join(self) -> bool: + return self.kind in ("SEMI", "ANTI") + + def on( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: + """ + Append to or set the ON expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).on("y = 1").sql() + 'JOIN x ON y = 1' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Join expression. + """ + join = _apply_conjunction_builder( + *expressions, + instance=self, + arg="on", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + + def using( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: + """ + Append to or set the USING expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).using("foo", "bla").sql() + 'JOIN x USING (foo, bla)' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append: if `True`, concatenate the new expressions to the existing "using" list. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Join expression. 
+ """ + join = _apply_list_builder( + *expressions, + instance=self, + arg="using", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + + +class Lateral(UDTF): + arg_types = { + "this": True, + "view": False, + "outer": False, + "alias": False, + "cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY + "ordinality": False, + } + + +# https://docs.snowflake.com/sql-reference/literals-table +# https://docs.snowflake.com/en/sql-reference/functions-table#using-a-table-function +class TableFromRows(UDTF): + arg_types = { + "this": True, + "alias": False, + "joins": False, + "pivots": False, + "sample": False, + } + + +class MatchRecognizeMeasure(Expression): + arg_types = { + "this": True, + "window_frame": False, + } + + +class MatchRecognize(Expression): + arg_types = { + "partition_by": False, + "order": False, + "measures": False, + "rows": False, + "after": False, + "pattern": False, + "define": False, + "alias": False, + } + + +# Clickhouse FROM FINAL modifier +# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier +class Final(Expression): + pass + + +class Offset(Expression): + arg_types = {"this": False, "expression": True, "expressions": False} + + +class Order(Expression): + arg_types = {"this": False, "expressions": True, "siblings": False} + + +# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier +class WithFill(Expression): + arg_types = { + "from_": False, + "to": False, + "step": False, + "interpolate": False, + } + + +# hive specific sorts +# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+SortBy +class Cluster(Order): + pass + + +class Distribute(Order): + pass + + +class Sort(Order): + pass + + +class Ordered(Expression): + arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False} + + @property + def name(self) -> str: + return self.this.name + + +class Property(Expression): + arg_types = {"this": True, "value": True} + + +class GrantPrivilege(Expression): + arg_types = {"this": True, "expressions": False} + + +class GrantPrincipal(Expression): + arg_types = {"this": True, "kind": False} + + +class AllowedValuesProperty(Expression): + arg_types = {"expressions": True} + + +class AlgorithmProperty(Property): + arg_types = {"this": True} + + +class AutoIncrementProperty(Property): + arg_types = {"this": True} + + +# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html +class AutoRefreshProperty(Property): + arg_types = {"this": True} + + +class BackupProperty(Property): + arg_types = {"this": True} + + +# https://doris.apache.org/docs/sql-manual/sql-statements/table-and-view/async-materialized-view/CREATE-ASYNC-MATERIALIZED-VIEW/ +class BuildProperty(Property): + arg_types = {"this": True} + + +class BlockCompressionProperty(Property): + arg_types = { + "autotemp": False, + "always": False, + "default": False, + "manual": False, + "never": False, + } + + +class CharacterSetProperty(Property): + arg_types = {"this": True, "default": True} + + +class ChecksumProperty(Property): + arg_types = {"on": False, "default": False} + + +class CollateProperty(Property): + arg_types = {"this": True, "default": False} + + +class CopyGrantsProperty(Property): + arg_types = {} + + +class DataBlocksizeProperty(Property): + arg_types = { + "size": False, + "units": False, + "minimum": False, + "maximum": False, + 
"default": False, + } + + +class DataDeletionProperty(Property): + arg_types = {"on": True, "filter_column": False, "retention_period": False} + + +class DefinerProperty(Property): + arg_types = {"this": True} + + +class DistKeyProperty(Property): + arg_types = {"this": True} + + +# https://docs.starrocks.io/docs/sql-reference/sql-statements/data-definition/CREATE_TABLE/#distribution_desc +# https://doris.apache.org/docs/sql-manual/sql-statements/Data-Definition-Statements/Create/CREATE-TABLE?_highlight=create&_highlight=table#distribution_desc +class DistributedByProperty(Property): + arg_types = {"expressions": False, "kind": True, "buckets": False, "order": False} + + +class DistStyleProperty(Property): + arg_types = {"this": True} + + +class DuplicateKeyProperty(Property): + arg_types = {"expressions": True} + + +class EngineProperty(Property): + arg_types = {"this": True} + + +class HeapProperty(Property): + arg_types = {} + + +class ToTableProperty(Property): + arg_types = {"this": True} + + +class ExecuteAsProperty(Property): + arg_types = {"this": True} + + +class ExternalProperty(Property): + arg_types = {"this": False} + + +class FallbackProperty(Property): + arg_types = {"no": True, "protection": False} + + +# https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-syntax-ddl-create-table-hiveformat +class FileFormatProperty(Property): + arg_types = {"this": False, "expressions": False, "hive_format": False} + + +class CredentialsProperty(Property): + arg_types = {"expressions": True} + + +class FreespaceProperty(Property): + arg_types = {"this": True, "percent": False} + + +class GlobalProperty(Property): + arg_types = {} + + +class IcebergProperty(Property): + arg_types = {} + + +class InheritsProperty(Property): + arg_types = {"expressions": True} + + +class InputModelProperty(Property): + arg_types = {"this": True} + + +class OutputModelProperty(Property): + arg_types = {"this": True} + + +class IsolatedLoadingProperty(Property): + arg_types = {"no": False, "concurrent": False, "target": False} + + +class JournalProperty(Property): + arg_types = { + "no": False, + "dual": False, + "before": False, + "local": False, + "after": False, + } + + +class LanguageProperty(Property): + arg_types = {"this": True} + + +class EnviromentProperty(Property): + arg_types = {"expressions": True} + + +# spark ddl +class ClusteredByProperty(Property): + arg_types = {"expressions": True, "sorted_by": False, "buckets": True} + + +class DictProperty(Property): + arg_types = {"this": True, "kind": True, "settings": False} + + +class DictSubProperty(Property): + pass + + +class DictRange(Property): + arg_types = {"this": True, "min": True, "max": True} + + +class DynamicProperty(Property): + arg_types = {} + + +# Clickhouse CREATE ... 
ON CLUSTER modifier +# https://clickhouse.com/docs/en/sql-reference/distributed-ddl +class OnCluster(Property): + arg_types = {"this": True} + + +# Clickhouse EMPTY table "property" +class EmptyProperty(Property): + arg_types = {} + + +class LikeProperty(Property): + arg_types = {"this": True, "expressions": False} + + +class LocationProperty(Property): + arg_types = {"this": True} + + +class LockProperty(Property): + arg_types = {"this": True} + + +class LockingProperty(Property): + arg_types = { + "this": False, + "kind": True, + "for_or_in": False, + "lock_type": True, + "override": False, + } + + +class LogProperty(Property): + arg_types = {"no": True} + + +class MaterializedProperty(Property): + arg_types = {"this": False} + + +class MergeBlockRatioProperty(Property): + arg_types = {"this": False, "no": False, "default": False, "percent": False} + + +class NoPrimaryIndexProperty(Property): + arg_types = {} + + +class OnProperty(Property): + arg_types = {"this": True} + + +class OnCommitProperty(Property): + arg_types = {"delete": False} + + +class PartitionedByProperty(Property): + arg_types = {"this": True} + + +class PartitionedByBucket(Property): + arg_types = {"this": True, "expression": True} + + +class PartitionByTruncate(Property): + arg_types = {"this": True, "expression": True} + + +# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ +class PartitionByRangeProperty(Property): + arg_types = {"partition_expressions": True, "create_expressions": True} + + +# https://docs.starrocks.io/docs/table_design/data_distribution/#range-partitioning +class PartitionByRangePropertyDynamic(Expression): + arg_types = {"this": False, "start": True, "end": True, "every": True} + + +# https://doris.apache.org/docs/table-design/data-partitioning/manual-partitioning +class PartitionByListProperty(Property): + arg_types = {"partition_expressions": True, "create_expressions": True} + + +# https://doris.apache.org/docs/table-design/data-partitioning/manual-partitioning +class PartitionList(Expression): + arg_types = {"this": True, "expressions": True} + + +# https://doris.apache.org/docs/sql-manual/sql-statements/table-and-view/async-materialized-view/CREATE-ASYNC-MATERIALIZED-VIEW +class RefreshTriggerProperty(Property): + arg_types = { + "method": True, + "kind": False, + "every": False, + "unit": False, + "starts": False, + } + + +# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ +class UniqueKeyProperty(Property): + arg_types = {"expressions": True} + + +# https://www.postgresql.org/docs/current/sql-createtable.html +class PartitionBoundSpec(Expression): + # this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...) + arg_types = { + "this": False, + "expression": False, + "from_expressions": False, + "to_expressions": False, + } + + +class PartitionedOfProperty(Property): + # this -> parent_table (schema), expression -> FOR VALUES ... 
/ DEFAULT + arg_types = {"this": True, "expression": True} + + +class StreamingTableProperty(Property): + arg_types = {} + + +class RemoteWithConnectionModelProperty(Property): + arg_types = {"this": True} + + +class ReturnsProperty(Property): + arg_types = {"this": False, "is_table": False, "table": False, "null": False} + + +class StrictProperty(Property): + arg_types = {} + + +class RowFormatProperty(Property): + arg_types = {"this": True} + + +class RowFormatDelimitedProperty(Property): + # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml + arg_types = { + "fields": False, + "escaped": False, + "collection_items": False, + "map_keys": False, + "lines": False, + "null": False, + "serde": False, + } + + +class RowFormatSerdeProperty(Property): + arg_types = {"this": True, "serde_properties": False} + + +# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html +class QueryTransform(Expression): + arg_types = { + "expressions": True, + "command_script": True, + "schema": False, + "row_format_before": False, + "record_writer": False, + "row_format_after": False, + "record_reader": False, + } + + +class SampleProperty(Property): + arg_types = {"this": True} + + +# https://prestodb.io/docs/current/sql/create-view.html#synopsis +class SecurityProperty(Property): + arg_types = {"this": True} + + +class SchemaCommentProperty(Property): + arg_types = {"this": True} + + +class SemanticView(Expression): + arg_types = { + "this": True, + "metrics": False, + "dimensions": False, + "facts": False, + "where": False, + } + + +class SerdeProperties(Property): + arg_types = {"expressions": True, "with_": False} + + +class SetProperty(Property): + arg_types = {"multi": True} + + +class SharingProperty(Property): + arg_types = {"this": False} + + +class SetConfigProperty(Property): + arg_types = {"this": True} + + +class SettingsProperty(Property): + arg_types = {"expressions": True} + + +class SortKeyProperty(Property): + arg_types = {"this": True, "compound": False} + + +class SqlReadWriteProperty(Property): + arg_types = {"this": True} + + +class SqlSecurityProperty(Property): + arg_types = {"this": True} + + +class StabilityProperty(Property): + arg_types = {"this": True} + + +class StorageHandlerProperty(Property): + arg_types = {"this": True} + + +class TemporaryProperty(Property): + arg_types = {"this": False} + + +class SecureProperty(Property): + arg_types = {} + + +# https://docs.snowflake.com/en/sql-reference/sql/create-table +class Tags(ColumnConstraintKind, Property): + arg_types = {"expressions": True} + + +class TransformModelProperty(Property): + arg_types = {"expressions": True} + + +class TransientProperty(Property): + arg_types = {"this": False} + + +class UnloggedProperty(Property): + arg_types = {} + + +# https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-using-template +class UsingTemplateProperty(Property): + arg_types = {"this": True} + + +# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-view-transact-sql?view=sql-server-ver16 +class ViewAttributeProperty(Property): + arg_types = {"this": True} + + +class VolatileProperty(Property): + arg_types = {"this": False} + + +class WithDataProperty(Property): + arg_types = {"no": True, "statistics": False} + + +class WithJournalTableProperty(Property): + arg_types = {"this": True} + + +class WithSchemaBindingProperty(Property): + arg_types = {"this": True} + + +class WithSystemVersioningProperty(Property): + arg_types = { + "on": False, + "this": False, + 
"data_consistency": False, + "retention_period": False, + "with_": True, + } + + +class WithProcedureOptions(Property): + arg_types = {"expressions": True} + + +class EncodeProperty(Property): + arg_types = {"this": True, "properties": False, "key": False} + + +class IncludeProperty(Property): + arg_types = {"this": True, "alias": False, "column_def": False} + + +class ForceProperty(Property): + arg_types = {} + + +class Properties(Expression): + arg_types = {"expressions": True} + + NAME_TO_PROPERTY = { + "ALGORITHM": AlgorithmProperty, + "AUTO_INCREMENT": AutoIncrementProperty, + "CHARACTER SET": CharacterSetProperty, + "CLUSTERED_BY": ClusteredByProperty, + "COLLATE": CollateProperty, + "COMMENT": SchemaCommentProperty, + "CREDENTIALS": CredentialsProperty, + "DEFINER": DefinerProperty, + "DISTKEY": DistKeyProperty, + "DISTRIBUTED_BY": DistributedByProperty, + "DISTSTYLE": DistStyleProperty, + "ENGINE": EngineProperty, + "EXECUTE AS": ExecuteAsProperty, + "FORMAT": FileFormatProperty, + "LANGUAGE": LanguageProperty, + "LOCATION": LocationProperty, + "LOCK": LockProperty, + "PARTITIONED_BY": PartitionedByProperty, + "RETURNS": ReturnsProperty, + "ROW_FORMAT": RowFormatProperty, + "SORTKEY": SortKeyProperty, + "ENCODE": EncodeProperty, + "INCLUDE": IncludeProperty, + } + + PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + + # CREATE property locations + # Form: schema specified + # create [POST_CREATE] + # table a [POST_NAME] + # (b int) [POST_SCHEMA] + # with ([POST_WITH]) + # index (b) [POST_INDEX] + # + # Form: alias selection + # create [POST_CREATE] + # table a [POST_NAME] + # as [POST_ALIAS] (select * from b) [POST_EXPRESSION] + # index (c) [POST_INDEX] + class Location(AutoName): + POST_CREATE = auto() + POST_NAME = auto() + POST_SCHEMA = auto() + POST_WITH = auto() + POST_ALIAS = auto() + POST_EXPRESSION = auto() + POST_INDEX = auto() + UNSUPPORTED = auto() + + @classmethod + def from_dict(cls, properties_dict: t.Dict) -> Properties: + expressions = [] + for key, value in properties_dict.items(): + property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) + if property_cls: + expressions.append(property_cls(this=convert(value))) + else: + expressions.append( + Property(this=Literal.string(key), value=convert(value)) + ) + + return cls(expressions=expressions) + + +class Qualify(Expression): + pass + + +class InputOutputFormat(Expression): + arg_types = {"input_format": False, "output_format": False} + + +# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql +class Return(Expression): + pass + + +class Reference(Expression): + arg_types = {"this": True, "expressions": False, "options": False} + + +class Tuple(Expression): + arg_types = {"expressions": False} + + def isin( + self, + *expressions: t.Any, + query: t.Optional[ExpOrStr] = None, + unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, + copy: bool = True, + **opts, + ) -> In: + return In( + this=maybe_copy(self, copy), + expressions=[convert(e, copy=copy) for e in expressions], + query=maybe_parse(query, copy=copy, **opts) if query else None, + unnest=( + Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) + for e in ensure_list(unnest) + ] + ) + if unnest + else None + ), + ) + + +QUERY_MODIFIERS = { + "match": False, + "laterals": False, + "joins": False, + "connect": False, + "pivots": False, + "prewhere": False, + "where": False, + "group": False, + "having": False, + "qualify": False, + "windows": False, + "distribute": False, + "sort": False, + 
"cluster": False, + "order": False, + "limit": False, + "offset": False, + "locks": False, + "sample": False, + "settings": False, + "format": False, + "options": False, +} + + +# https://learn.microsoft.com/en-us/sql/t-sql/queries/option-clause-transact-sql?view=sql-server-ver16 +# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-query?view=sql-server-ver16 +class QueryOption(Expression): + arg_types = {"this": True, "expression": False} + + +# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16 +class WithTableHint(Expression): + arg_types = {"expressions": True} + + +# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html +class IndexTableHint(Expression): + arg_types = {"this": True, "expressions": False, "target": False} + + +# https://docs.snowflake.com/en/sql-reference/constructs/at-before +class HistoricalData(Expression): + arg_types = {"this": True, "kind": True, "expression": True} + + +# https://docs.snowflake.com/en/sql-reference/sql/put +class Put(Expression): + arg_types = {"this": True, "target": True, "properties": False} + + +# https://docs.snowflake.com/en/sql-reference/sql/get +class Get(Expression): + arg_types = {"this": True, "target": True, "properties": False} + + +class Table(Expression): + arg_types = { + "this": False, + "alias": False, + "db": False, + "catalog": False, + "laterals": False, + "joins": False, + "pivots": False, + "hints": False, + "system_time": False, + "version": False, + "format": False, + "pattern": False, + "ordinality": False, + "when": False, + "only": False, + "partition": False, + "changes": False, + "rows_from": False, + "sample": False, + "indexed": False, + } + + @property + def name(self) -> str: + if not self.this or isinstance(self.this, Func): + return "" + return self.this.name + + @property + def db(self) -> str: + return self.text("db") + + @property + def catalog(self) -> str: + return self.text("catalog") + + @property + def selects(self) -> t.List[Expression]: + return [] + + @property + def named_selects(self) -> t.List[str]: + return [] + + @property + def parts(self) -> t.List[Expression]: + """Return the parts of a table in order catalog, db, table.""" + parts: t.List[Expression] = [] + + for arg in ("catalog", "db", "this"): + part = self.args.get(arg) + + if isinstance(part, Dot): + parts.extend(part.flatten()) + elif isinstance(part, Expression): + parts.append(part) + + return parts + + def to_column(self, copy: bool = True) -> Expression: + parts = self.parts + last_part = parts[-1] + + if isinstance(last_part, Identifier): + col: Expression = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore + else: + # This branch will be reached if a function or array is wrapped in a `Table` + col = last_part + + alias = self.args.get("alias") + if alias: + col = alias_(col, alias.this, copy=copy) + + return col + + +class SetOperation(Query): + arg_types = { + "with_": False, + "this": True, + "expression": True, + "distinct": False, + "by_name": False, + "side": False, + "kind": False, + "on": False, + **QUERY_MODIFIERS, + } + + def select( + self: S, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> S: + this = maybe_copy(self, copy) + this.this.unnest().select( + *expressions, append=append, dialect=dialect, copy=False, **opts + ) + this.expression.unnest().select( + *expressions, append=append, dialect=dialect, copy=False, **opts + ) + return this + + 
@property + def named_selects(self) -> t.List[str]: + expression = self + while isinstance(expression, SetOperation): + expression = expression.this.unnest() + return expression.named_selects + + @property + def is_star(self) -> bool: + return self.this.is_star or self.expression.is_star + + @property + def selects(self) -> t.List[Expression]: + expression = self + while isinstance(expression, SetOperation): + expression = expression.this.unnest() + return expression.selects + + @property + def left(self) -> Query: + return self.this + + @property + def right(self) -> Query: + return self.expression + + @property + def kind(self) -> str: + return self.text("kind").upper() + + @property + def side(self) -> str: + return self.text("side").upper() + + +class Union(SetOperation): + pass + + +class Except(SetOperation): + pass + + +class Intersect(SetOperation): + pass + + +class Update(DML): + arg_types = { + "with_": False, + "this": False, + "expressions": False, + "from_": False, + "where": False, + "returning": False, + "order": False, + "limit": False, + "options": False, + } + + def table( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Set the table to update. + + Example: + >>> Update().table("my_table").set_("x = 1").sql() + 'UPDATE my_table SET x = 1' + + Args: + expression : the SQL code strings to parse. + If a `Table` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Table`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Update expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="this", + into=Table, + prefix=None, + dialect=dialect, + copy=copy, + **opts, + ) + + def set_( + self, + *expressions: ExpOrStr, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Append to or set the SET expressions. + + Example: + >>> Update().table("my_table").set_("x = 1").sql() + 'UPDATE my_table SET x = 1' + + Args: + *expressions: the SQL code strings to parse. + If `Expression` instance(s) are passed, they will be used as-is. + Multiple expressions are combined with a comma. + append: if `True`, add the new expressions to any existing SET expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="expressions", + append=append, + into=Expression, + prefix=None, + dialect=dialect, + copy=copy, + **opts, + ) + + def where( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Append to or set the WHERE expressions. + + Example: + >>> Update().table("tbl").set_("x = 1").where("x = 'a' OR x < 'b'").sql() + "UPDATE tbl SET x = 1 WHERE x = 'a' OR x < 'b'" + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. 
+ dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def from_( + self, + expression: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Set the FROM expression. + + Example: + >>> Update().table("my_table").set_("x = 1").from_("baz").sql() + 'UPDATE my_table SET x = 1 FROM baz' + + Args: + expression : the SQL code strings to parse. + If a `From` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `From`. + If nothing is passed in then a from is not applied to the expression + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Update expression. + """ + if not expression: + return maybe_copy(self, copy) + + return _apply_builder( + expression=expression, + instance=self, + arg="from_", + into=From, + prefix="FROM", + dialect=dialect, + copy=copy, + **opts, + ) + + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + materialized: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Append to or set the common table expressions. + + Example: + >>> Update().table("my_table").set_("x = 1").from_("baz").with_("baz", "SELECT id FROM foo").sql() + 'WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + materialized: set the MATERIALIZED part of the expression. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. 
+ """ + return _apply_cte_builder( + self, + alias, + as_, + recursive=recursive, + materialized=materialized, + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + +# DuckDB supports VALUES followed by https://duckdb.org/docs/stable/sql/query_syntax/limit +class Values(UDTF): + arg_types = { + "expressions": True, + "alias": False, + "order": False, + "limit": False, + "offset": False, + } + + +class Var(Expression): + pass + + +class Version(Expression): + """ + Time travel, iceberg, bigquery etc + https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots + https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html + https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of + https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16 + this is either TIMESTAMP or VERSION + kind is ("AS OF", "BETWEEN") + """ + + arg_types = {"this": True, "kind": True, "expression": False} + + +class Schema(Expression): + arg_types = {"this": False, "expressions": False} + + +# https://dev.mysql.com/doc/refman/8.0/en/select.html +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/SELECT.html +class Lock(Expression): + arg_types = {"update": True, "expressions": False, "wait": False, "key": False} + + +class Select(Query): + arg_types = { + "with_": False, + "kind": False, + "expressions": False, + "hint": False, + "distinct": False, + "into": False, + "from_": False, + "operation_modifiers": False, + **QUERY_MODIFIERS, + } + + def from_( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Set the FROM expression. + + Example: + >>> Select().from_("tbl").select("x").sql() + 'SELECT x FROM tbl' + + Args: + expression : the SQL code strings to parse. + If a `From` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `From`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="from_", + into=From, + prefix="FROM", + dialect=dialect, + copy=copy, + **opts, + ) + + def group_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Set the GROUP BY expression. + + Example: + >>> Select().from_("tbl").select("x", "COUNT(1)").group_by("x").sql() + 'SELECT x, COUNT(1) FROM tbl GROUP BY x' + + Args: + *expressions: the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Group`. + If nothing is passed in then a group by is not applied to the expression + append: if `True`, add to any existing expressions. + Otherwise, this flattens all the `Group` expression into a single expression. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. 
+ """ + if not expressions: + return self if not copy else self.copy() + + return _apply_child_list_builder( + *expressions, + instance=self, + arg="group", + append=append, + copy=copy, + prefix="GROUP BY", + into=Group, + dialect=dialect, + **opts, + ) + + def sort_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Set the SORT BY expression. + + Example: + >>> Select().from_("tbl").select("x").sort_by("x DESC").sql(dialect="hive") + 'SELECT x FROM tbl SORT BY x DESC' + + Args: + *expressions: the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `SORT`. + append: if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="sort", + append=append, + copy=copy, + prefix="SORT BY", + into=Sort, + dialect=dialect, + **opts, + ) + + def cluster_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Set the CLUSTER BY expression. + + Example: + >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql(dialect="hive") + 'SELECT x FROM tbl CLUSTER BY x DESC' + + Args: + *expressions: the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Cluster`. + append: if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="cluster", + append=append, + copy=copy, + prefix="CLUSTER BY", + into=Cluster, + dialect=dialect, + **opts, + ) + + def select( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + return _apply_list_builder( + *expressions, + instance=self, + arg="expressions", + append=append, + dialect=dialect, + into=Expression, + copy=copy, + **opts, + ) + + def lateral( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Append to or set the LATERAL expressions. + + Example: + >>> Select().select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl").sql() + 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. 
+ + Returns: + The modified Select expression. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="laterals", + append=append, + into=Lateral, + prefix="LATERAL VIEW", + dialect=dialect, + copy=copy, + **opts, + ) + + def join( + self, + expression: ExpOrStr, + on: t.Optional[ExpOrStr] = None, + using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None, + append: bool = True, + join_type: t.Optional[str] = None, + join_alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Append to or set the JOIN expressions. + + Example: + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() + 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' + + >>> Select().select("1").from_("a").join("b", using=["x", "y", "z"]).sql() + 'SELECT 1 FROM a JOIN b USING (x, y, z)' + + Use `join_type` to change the type of join: + + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() + 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' + + Args: + expression: the SQL code string to parse. + If an `Expression` instance is passed, it will be used as-is. + on: optionally specify the join "on" criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + using: optionally specify the join "using" criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + join_type: if set, alter the parsed join type. + join_alias: an optional alias for the joined source. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts} + + try: + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) + except ParseError: + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) + + join = expression if isinstance(expression, Join) else Join(this=expression) + + if isinstance(join.this, Select): + join.this.replace(join.this.subquery()) + + if join_type: + method: t.Optional[Token] + side: t.Optional[Token] + kind: t.Optional[Token] + + method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore + + if method: + join.set("method", method.text) + if side: + join.set("side", side.text) + if kind: + join.set("kind", kind.text) + + if on: + on = and_(*ensure_list(on), dialect=dialect, copy=copy, **opts) + join.set("on", on) + + if using: + join = _apply_list_builder( + *ensure_list(using), + instance=join, + arg="using", + append=append, + copy=copy, + into=Identifier, + **opts, + ) + + if join_alias: + join.set("this", alias_(join.this, join_alias, table=True)) + + return _apply_list_builder( + join, + instance=self, + arg="joins", + append=append, + copy=copy, + **opts, + ) + + def having( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Append to or set the HAVING expressions. + + Example: + >>> Select().select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 3").sql() + 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' + + Args: + *expressions: the SQL code strings to parse. 
+ If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="having", + append=append, + into=Having, + dialect=dialect, + copy=copy, + **opts, + ) + + def window( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + return _apply_list_builder( + *expressions, + instance=self, + arg="windows", + append=append, + into=Window, + dialect=dialect, + copy=copy, + **opts, + ) + + def qualify( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="qualify", + append=append, + into=Qualify, + dialect=dialect, + copy=copy, + **opts, + ) + + def distinct( + self, *ons: t.Optional[ExpOrStr], distinct: bool = True, copy: bool = True + ) -> Select: + """ + Set the DISTINCT expression. + + Example: + >>> Select().from_("tbl").select("x").distinct().sql() + 'SELECT DISTINCT x FROM tbl' + + Args: + ons: the expressions to distinct on + distinct: whether the Select should be distinct + copy: if `False`, modify this expression instance in-place. + + Returns: + Select: the modified expression. + """ + instance = maybe_copy(self, copy) + on = ( + Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) + if ons + else None + ) + instance.set("distinct", Distinct(on=on) if distinct else None) + return instance + + def ctas( + self, + table: ExpOrStr, + properties: t.Optional[t.Dict] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Create: + """ + Convert this expression to a CREATE TABLE AS statement. + + Example: + >>> Select().select("*").from_("tbl").ctas("x").sql() + 'CREATE TABLE x AS SELECT * FROM tbl' + + Args: + table: the SQL code string to parse as the table name. + If another `Expression` instance is passed, it will be used as-is. + properties: an optional mapping of table properties + dialect: the dialect used to parse the input table. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input table. + + Returns: + The new Create expression. + """ + instance = maybe_copy(self, copy) + table_expression = maybe_parse(table, into=Table, dialect=dialect, **opts) + + properties_expression = None + if properties: + properties_expression = Properties.from_dict(properties) + + return Create( + this=table_expression, + kind="TABLE", + expression=instance, + properties=properties_expression, + ) + + def lock(self, update: bool = True, copy: bool = True) -> Select: + """ + Set the locking read mode for this expression. + + Examples: + >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE" + + >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR SHARE" + + Args: + update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+ copy: if `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + inst = maybe_copy(self, copy) + inst.set("locks", [Lock(update=update)]) + + return inst + + def hint( + self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True + ) -> Select: + """ + Set hints for this expression. + + Examples: + >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark") + 'SELECT /*+ BROADCAST(y) */ x FROM tbl' + + Args: + hints: The SQL code strings to parse as the hints. + If an `Expression` instance is passed, it will be used as-is. + dialect: The dialect used to parse the hints. + copy: If `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + inst = maybe_copy(self, copy) + inst.set( + "hint", + Hint( + expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints] + ), + ) + + return inst + + @property + def named_selects(self) -> t.List[str]: + selects = [] + + for e in self.expressions: + if e.alias_or_name: + selects.append(e.output_name) + elif isinstance(e, Aliases): + selects.extend([a.name for a in e.aliases]) + return selects + + @property + def is_star(self) -> bool: + return any(expression.is_star for expression in self.expressions) + + @property + def selects(self) -> t.List[Expression]: + return self.expressions + + +UNWRAPPED_QUERIES = (Select, SetOperation) + + +class Subquery(DerivedTable, Query): + arg_types = { + "this": True, + "alias": False, + "with_": False, + **QUERY_MODIFIERS, + } + + def unnest(self): + """Returns the first non subquery.""" + expression = self + while isinstance(expression, Subquery): + expression = expression.this + return expression + + def unwrap(self) -> Subquery: + expression = self + while expression.same_parent and expression.is_wrapper: + expression = t.cast(Subquery, expression.parent) + return expression + + def select( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Subquery: + this = maybe_copy(self, copy) + this.unnest().select( + *expressions, append=append, dialect=dialect, copy=False, **opts + ) + return this + + @property + def is_wrapper(self) -> bool: + """ + Whether this Subquery acts as a simple wrapper around another expression. 
+ + SELECT * FROM (((SELECT * FROM t))) + ^ + This corresponds to a "wrapper" Subquery node + """ + return all(v is None for k, v in self.args.items() if k != "this") + + @property + def is_star(self) -> bool: + return self.this.is_star + + @property + def output_name(self) -> str: + return self.alias + + +class TableSample(Expression): + arg_types = { + "expressions": False, + "method": False, + "bucket_numerator": False, + "bucket_denominator": False, + "bucket_field": False, + "percent": False, + "rows": False, + "size": False, + "seed": False, + } + + +class Tag(Expression): + """Tags are used for generating arbitrary sql like SELECT x.""" + + arg_types = { + "this": False, + "prefix": False, + "postfix": False, + } + + +# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax +# https://duckdb.org/docs/sql/statements/pivot +class Pivot(Expression): + arg_types = { + "this": False, + "alias": False, + "expressions": False, + "fields": False, + "unpivot": False, + "using": False, + "group": False, + "columns": False, + "include_nulls": False, + "default_on_null": False, + "into": False, + "with_": False, + } + + @property + def unpivot(self) -> bool: + return bool(self.args.get("unpivot")) + + @property + def fields(self) -> t.List[Expression]: + return self.args.get("fields", []) + + +# https://duckdb.org/docs/sql/statements/unpivot#simplified-unpivot-syntax +# UNPIVOT ... INTO [NAME VALUE ][...,] +class UnpivotColumns(Expression): + arg_types = {"this": True, "expressions": True} + + +class Window(Condition): + arg_types = { + "this": True, + "partition_by": False, + "order": False, + "spec": False, + "alias": False, + "over": False, + "first": False, + } + + +class WindowSpec(Expression): + arg_types = { + "kind": False, + "start": False, + "start_side": False, + "end": False, + "end_side": False, + "exclude": False, + } + + +class PreWhere(Expression): + pass + + +class Where(Expression): + pass + + +class Star(Expression): + arg_types = {"except_": False, "replace": False, "rename": False} + + @property + def name(self) -> str: + return "*" + + @property + def output_name(self) -> str: + return self.name + + +class Parameter(Condition): + arg_types = {"this": True, "expression": False} + + +class SessionParameter(Condition): + arg_types = {"this": True, "kind": False} + + +# https://www.databricks.com/blog/parameterized-queries-pyspark +# https://jdbc.postgresql.org/documentation/query/#using-the-statement-or-preparedstatement-interface +class Placeholder(Condition): + arg_types = {"this": False, "kind": False, "widget": False, "jdbc": False} + + @property + def name(self) -> str: + return self.this or "?" + + +class Null(Condition): + arg_types: t.Dict[str, t.Any] = {} + + @property + def name(self) -> str: + return "NULL" + + def to_py(self) -> Lit[None]: + return None + + +class Boolean(Condition): + def to_py(self) -> bool: + return self.this + + +class DataTypeParam(Expression): + arg_types = {"this": True, "expression": False} + + @property + def name(self) -> str: + return self.this.name + + +# The `nullable` arg is helpful when transpiling types from other dialects to ClickHouse, which +# assumes non-nullable types by default. Values `None` and `True` mean the type is nullable. 
+class DataType(Expression): + arg_types = { + "this": True, + "expressions": False, + "nested": False, + "values": False, + "prefix": False, + "kind": False, + "nullable": False, + } + + class Type(AutoName): + ARRAY = auto() + AGGREGATEFUNCTION = auto() + SIMPLEAGGREGATEFUNCTION = auto() + BIGDECIMAL = auto() + BIGINT = auto() + BIGNUM = auto() + BIGSERIAL = auto() + BINARY = auto() + BIT = auto() + BLOB = auto() + BOOLEAN = auto() + BPCHAR = auto() + CHAR = auto() + DATE = auto() + DATE32 = auto() + DATEMULTIRANGE = auto() + DATERANGE = auto() + DATETIME = auto() + DATETIME2 = auto() + DATETIME64 = auto() + DECIMAL = auto() + DECIMAL32 = auto() + DECIMAL64 = auto() + DECIMAL128 = auto() + DECIMAL256 = auto() + DECFLOAT = auto() + DOUBLE = auto() + DYNAMIC = auto() + ENUM = auto() + ENUM8 = auto() + ENUM16 = auto() + FILE = auto() + FIXEDSTRING = auto() + FLOAT = auto() + GEOGRAPHY = auto() + GEOGRAPHYPOINT = auto() + GEOMETRY = auto() + POINT = auto() + RING = auto() + LINESTRING = auto() + MULTILINESTRING = auto() + POLYGON = auto() + MULTIPOLYGON = auto() + HLLSKETCH = auto() + HSTORE = auto() + IMAGE = auto() + INET = auto() + INT = auto() + INT128 = auto() + INT256 = auto() + INT4MULTIRANGE = auto() + INT4RANGE = auto() + INT8MULTIRANGE = auto() + INT8RANGE = auto() + INTERVAL = auto() + IPADDRESS = auto() + IPPREFIX = auto() + IPV4 = auto() + IPV6 = auto() + JSON = auto() + JSONB = auto() + LIST = auto() + LONGBLOB = auto() + LONGTEXT = auto() + LOWCARDINALITY = auto() + MAP = auto() + MEDIUMBLOB = auto() + MEDIUMINT = auto() + MEDIUMTEXT = auto() + MONEY = auto() + NAME = auto() + NCHAR = auto() + NESTED = auto() + NOTHING = auto() + NULL = auto() + NUMMULTIRANGE = auto() + NUMRANGE = auto() + NVARCHAR = auto() + OBJECT = auto() + RANGE = auto() + ROWVERSION = auto() + SERIAL = auto() + SET = auto() + SMALLDATETIME = auto() + SMALLINT = auto() + SMALLMONEY = auto() + SMALLSERIAL = auto() + STRUCT = auto() + SUPER = auto() + TEXT = auto() + TINYBLOB = auto() + TINYTEXT = auto() + TIME = auto() + TIMETZ = auto() + TIME_NS = auto() + TIMESTAMP = auto() + TIMESTAMPNTZ = auto() + TIMESTAMPLTZ = auto() + TIMESTAMPTZ = auto() + TIMESTAMP_S = auto() + TIMESTAMP_MS = auto() + TIMESTAMP_NS = auto() + TINYINT = auto() + TSMULTIRANGE = auto() + TSRANGE = auto() + TSTZMULTIRANGE = auto() + TSTZRANGE = auto() + UBIGINT = auto() + UINT = auto() + UINT128 = auto() + UINT256 = auto() + UMEDIUMINT = auto() + UDECIMAL = auto() + UDOUBLE = auto() + UNION = auto() + UNKNOWN = auto() # Sentinel value, useful for type annotation + USERDEFINED = "USER-DEFINED" + USMALLINT = auto() + UTINYINT = auto() + UUID = auto() + VARBINARY = auto() + VARCHAR = auto() + VARIANT = auto() + VECTOR = auto() + XML = auto() + YEAR = auto() + TDIGEST = auto() + + STRUCT_TYPES = { + Type.FILE, + Type.NESTED, + Type.OBJECT, + Type.STRUCT, + Type.UNION, + } + + ARRAY_TYPES = { + Type.ARRAY, + Type.LIST, + } + + NESTED_TYPES = { + *STRUCT_TYPES, + *ARRAY_TYPES, + Type.MAP, + } + + TEXT_TYPES = { + Type.CHAR, + Type.NCHAR, + Type.NVARCHAR, + Type.TEXT, + Type.VARCHAR, + Type.NAME, + } + + SIGNED_INTEGER_TYPES = { + Type.BIGINT, + Type.INT, + Type.INT128, + Type.INT256, + Type.MEDIUMINT, + Type.SMALLINT, + Type.TINYINT, + } + + UNSIGNED_INTEGER_TYPES = { + Type.UBIGINT, + Type.UINT, + Type.UINT128, + Type.UINT256, + Type.UMEDIUMINT, + Type.USMALLINT, + Type.UTINYINT, + } + + INTEGER_TYPES = { + *SIGNED_INTEGER_TYPES, + *UNSIGNED_INTEGER_TYPES, + Type.BIT, + } + + FLOAT_TYPES = { + Type.DOUBLE, + Type.FLOAT, + } + + REAL_TYPES = 
{ + *FLOAT_TYPES, + Type.BIGDECIMAL, + Type.DECIMAL, + Type.DECIMAL32, + Type.DECIMAL64, + Type.DECIMAL128, + Type.DECIMAL256, + Type.DECFLOAT, + Type.MONEY, + Type.SMALLMONEY, + Type.UDECIMAL, + Type.UDOUBLE, + } + + NUMERIC_TYPES = { + *INTEGER_TYPES, + *REAL_TYPES, + } + + TEMPORAL_TYPES = { + Type.DATE, + Type.DATE32, + Type.DATETIME, + Type.DATETIME2, + Type.DATETIME64, + Type.SMALLDATETIME, + Type.TIME, + Type.TIMESTAMP, + Type.TIMESTAMPNTZ, + Type.TIMESTAMPLTZ, + Type.TIMESTAMPTZ, + Type.TIMESTAMP_MS, + Type.TIMESTAMP_NS, + Type.TIMESTAMP_S, + Type.TIMETZ, + } + + @classmethod + def build( + cls, + dtype: DATA_TYPE, + dialect: DialectType = None, + udt: bool = False, + copy: bool = True, + **kwargs, + ) -> DataType: + """ + Constructs a DataType object. + + Args: + dtype: the data type of interest. + dialect: the dialect to use for parsing `dtype`, in case it's a string. + udt: when set to True, `dtype` will be used as-is if it can't be parsed into a + DataType, thus creating a user-defined type. + copy: whether to copy the data type. + kwargs: additional arguments to pass in the constructor of DataType. + + Returns: + The constructed DataType object. + """ + from bigframes_vendored.sqlglot import parse_one + + if isinstance(dtype, str): + if dtype.upper() == "UNKNOWN": + return DataType(this=DataType.Type.UNKNOWN, **kwargs) + + try: + data_type_exp = parse_one( + dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE + ) + except ParseError: + if udt: + return DataType( + this=DataType.Type.USERDEFINED, kind=dtype, **kwargs + ) + raise + elif isinstance(dtype, (Identifier, Dot)) and udt: + return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) + elif isinstance(dtype, DataType.Type): + data_type_exp = DataType(this=dtype) + elif isinstance(dtype, DataType): + return maybe_copy(dtype, copy) + else: + raise ValueError( + f"Invalid data type: {type(dtype)}. Expected str or DataType.Type" + ) + + return DataType(**{**data_type_exp.args, **kwargs}) + + def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool: + """ + Checks whether this DataType matches one of the provided data types. Nested types or precision + will be compared using "structural equivalence" semantics, so e.g. array != array. + + Args: + dtypes: the data types to compare this DataType to. + check_nullable: whether to take the NULLABLE type constructor into account for the comparison. + If false, it means that NULLABLE is equivalent to INT. + + Returns: + True, if and only if there is a type in `dtypes` which is equal to this DataType. + """ + self_is_nullable = self.args.get("nullable") + for dtype in dtypes: + other_type = DataType.build(dtype, copy=False, udt=True) + other_is_nullable = other_type.args.get("nullable") + if ( + other_type.expressions + or (check_nullable and (self_is_nullable or other_is_nullable)) + or self.this == DataType.Type.USERDEFINED + or other_type.this == DataType.Type.USERDEFINED + ): + matches = self == other_type + else: + matches = self.this == other_type.this + + if matches: + return True + return False + + +# https://www.postgresql.org/docs/15/datatype-pseudo.html +class PseudoType(DataType): + arg_types = {"this": True} + + +# https://www.postgresql.org/docs/15/datatype-oid.html +class ObjectIdentifier(DataType): + arg_types = {"this": True} + + +# WHERE x EXISTS|ALL|ANY|SOME(SELECT ...) 
+class SubqueryPredicate(Predicate): + pass + + +class All(SubqueryPredicate): + pass + + +class Any(SubqueryPredicate): + pass + + +# Commands to interact with the databases or engines. For most of the command +# expressions we parse whatever comes after the command's name as a string. +class Command(Expression): + arg_types = {"this": True, "expression": False} + + +class Transaction(Expression): + arg_types = {"this": False, "modes": False, "mark": False} + + +class Commit(Expression): + arg_types = {"chain": False, "this": False, "durability": False} + + +class Rollback(Expression): + arg_types = {"savepoint": False, "this": False} + + +class Alter(Expression): + arg_types = { + "this": False, + "kind": True, + "actions": True, + "exists": False, + "only": False, + "options": False, + "cluster": False, + "not_valid": False, + "check": False, + "cascade": False, + } + + @property + def kind(self) -> t.Optional[str]: + kind = self.args.get("kind") + return kind and kind.upper() + + @property + def actions(self) -> t.List[Expression]: + return self.args.get("actions") or [] + + +class AlterSession(Expression): + arg_types = {"expressions": True, "unset": False} + + +class Analyze(Expression): + arg_types = { + "kind": False, + "this": False, + "options": False, + "mode": False, + "partition": False, + "expression": False, + "properties": False, + } + + +class AnalyzeStatistics(Expression): + arg_types = { + "kind": True, + "option": False, + "this": False, + "expressions": False, + } + + +class AnalyzeHistogram(Expression): + arg_types = { + "this": True, + "expressions": True, + "expression": False, + "update_options": False, + } + + +class AnalyzeSample(Expression): + arg_types = {"kind": True, "sample": True} + + +class AnalyzeListChainedRows(Expression): + arg_types = {"expression": False} + + +class AnalyzeDelete(Expression): + arg_types = {"kind": False} + + +class AnalyzeWith(Expression): + arg_types = {"expressions": True} + + +class AnalyzeValidate(Expression): + arg_types = { + "kind": True, + "this": False, + "expression": False, + } + + +class AnalyzeColumns(Expression): + pass + + +class UsingData(Expression): + pass + + +class AddConstraint(Expression): + arg_types = {"expressions": True} + + +class AddPartition(Expression): + arg_types = {"this": True, "exists": False, "location": False} + + +class AttachOption(Expression): + arg_types = {"this": True, "expression": False} + + +class DropPartition(Expression): + arg_types = {"expressions": True, "exists": False} + + +# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#replace-partition +class ReplacePartition(Expression): + arg_types = {"expression": True, "source": True} + + +# Binary expressions like (ADD a b) +class Binary(Condition): + arg_types = {"this": True, "expression": True} + + @property + def left(self) -> Expression: + return self.this + + @property + def right(self) -> Expression: + return self.expression + + +class Add(Binary): + pass + + +class Connector(Binary): + pass + + +class BitwiseAnd(Binary): + arg_types = {"this": True, "expression": True, "padside": False} + + +class BitwiseLeftShift(Binary): + pass + + +class BitwiseOr(Binary): + arg_types = {"this": True, "expression": True, "padside": False} + + +class BitwiseRightShift(Binary): + pass + + +class BitwiseXor(Binary): + arg_types = {"this": True, "expression": True, "padside": False} + + +class Div(Binary): + arg_types = {"this": True, "expression": True, "typed": False, "safe": False} + + +class Overlaps(Binary): + pass + + 
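Since `Binary` is the base for all two-operand nodes above (`Add`, `Div`, the comparison predicates, and so on), a short sketch of how these nodes nest may help; it assumes the vendored `parse_one` behaves like upstream sqlglot's:

```python
import bigframes_vendored.sqlglot as sg
from bigframes_vendored.sqlglot import expressions as exp

# "a + b * 2" parses into Add(this=a, expression=Mul(b, 2)); Binary
# exposes the two operands through its `left`/`right` properties.
tree = sg.parse_one("a + b * 2")
assert isinstance(tree, exp.Add)
print(tree.left.sql(), "|", tree.right.sql())
# Expected: a | b * 2
```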
+class ExtendsLeft(Binary): + pass + + +class ExtendsRight(Binary): + pass + + +class Dot(Binary): + @property + def is_star(self) -> bool: + return self.expression.is_star + + @property + def name(self) -> str: + return self.expression.name + + @property + def output_name(self) -> str: + return self.name + + @classmethod + def build(self, expressions: t.Sequence[Expression]) -> Dot: + """Build a Dot object with a sequence of expressions.""" + if len(expressions) < 2: + raise ValueError("Dot requires >= 2 expressions.") + + return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions)) + + @property + def parts(self) -> t.List[Expression]: + """Return the parts of a table / column in order catalog, db, table.""" + this, *parts = self.flatten() + + parts.reverse() + + for arg in COLUMN_PARTS: + part = this.args.get(arg) + + if isinstance(part, Expression): + parts.append(part) + + parts.reverse() + return parts + + +DATA_TYPE = t.Union[str, Identifier, Dot, DataType, DataType.Type] + + +class DPipe(Binary): + arg_types = {"this": True, "expression": True, "safe": False} + + +class EQ(Binary, Predicate): + pass + + +class NullSafeEQ(Binary, Predicate): + pass + + +class NullSafeNEQ(Binary, Predicate): + pass + + +# Represents e.g. := in DuckDB which is mostly used for setting parameters +class PropertyEQ(Binary): + pass + + +class Distance(Binary): + pass + + +class Escape(Binary): + pass + + +class Glob(Binary, Predicate): + pass + + +class GT(Binary, Predicate): + pass + + +class GTE(Binary, Predicate): + pass + + +class ILike(Binary, Predicate): + pass + + +class IntDiv(Binary): + pass + + +class Is(Binary, Predicate): + pass + + +class Kwarg(Binary): + """Kwarg in special functions like func(kwarg => y).""" + + +class Like(Binary, Predicate): + pass + + +class Match(Binary, Predicate): + pass + + +class LT(Binary, Predicate): + pass + + +class LTE(Binary, Predicate): + pass + + +class Mod(Binary): + pass + + +class Mul(Binary): + pass + + +class NEQ(Binary, Predicate): + pass + + +# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH +class Operator(Binary): + arg_types = {"this": True, "operator": True, "expression": True} + + +class SimilarTo(Binary, Predicate): + pass + + +class Sub(Binary): + pass + + +# https://www.postgresql.org/docs/current/functions-range.html +# Represents range adjacency operator: -|- +class Adjacent(Binary): + pass + + +# Unary Expressions +# (NOT a) +class Unary(Condition): + pass + + +class BitwiseNot(Unary): + pass + + +class Not(Unary): + pass + + +class Paren(Unary): + @property + def output_name(self) -> str: + return self.this.name + + +class Neg(Unary): + def to_py(self) -> int | Decimal: + if self.is_number: + return self.this.to_py() * -1 + return super().to_py() + + +class Alias(Expression): + arg_types = {"this": True, "alias": False} + + @property + def output_name(self) -> str: + return self.alias + + +# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but +# other dialects require identifiers. This enables us to transpile between them easily. +class PivotAlias(Alias): + pass + + +# Represents Snowflake's ANY [ ORDER BY ... 
] syntax +# https://docs.snowflake.com/en/sql-reference/constructs/pivot +class PivotAny(Expression): + arg_types = {"this": False} + + +class Aliases(Expression): + arg_types = {"this": True, "expressions": True} + + @property + def aliases(self): + return self.expressions + + +# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html +class AtIndex(Expression): + arg_types = {"this": True, "expression": True} + + +class AtTimeZone(Expression): + arg_types = {"this": True, "zone": True} + + +class FromTimeZone(Expression): + arg_types = {"this": True, "zone": True} + + +class FormatPhrase(Expression): + """Format override for a column in Teradata. + Can be expanded to additional dialects as needed + + https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Data-Types-and-Literals/Data-Type-Formats-and-Format-Phrases/FORMAT + """ + + arg_types = {"this": True, "format": True} + + +class Between(Predicate): + arg_types = {"this": True, "low": True, "high": True, "symmetric": False} + + +class Bracket(Condition): + # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator + arg_types = { + "this": True, + "expressions": True, + "offset": False, + "safe": False, + "returns_list_for_maps": False, + } + + @property + def output_name(self) -> str: + if len(self.expressions) == 1: + return self.expressions[0].output_name + + return super().output_name + + +class Distinct(Expression): + arg_types = {"expressions": False, "on": False} + + +class In(Predicate): + arg_types = { + "this": True, + "expressions": False, + "query": False, + "unnest": False, + "field": False, + "is_global": False, + } + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in +class ForIn(Expression): + arg_types = {"this": True, "expression": True} + + +class TimeUnit(Expression): + """Automatically converts unit arg into a var.""" + + arg_types = {"unit": False} + + UNABBREVIATED_UNIT_NAME = { + "D": "DAY", + "H": "HOUR", + "M": "MINUTE", + "MS": "MILLISECOND", + "NS": "NANOSECOND", + "Q": "QUARTER", + "S": "SECOND", + "US": "MICROSECOND", + "W": "WEEK", + "Y": "YEAR", + } + + VAR_LIKE = (Column, Literal, Var) + + def __init__(self, **args): + unit = args.get("unit") + if type(unit) in self.VAR_LIKE and not ( + isinstance(unit, Column) and len(unit.parts) != 1 + ): + args["unit"] = Var( + this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper() + ) + elif isinstance(unit, Week): + unit.set("this", Var(this=unit.this.name.upper())) + + super().__init__(**args) + + @property + def unit(self) -> t.Optional[Var | IntervalSpan]: + return self.args.get("unit") + + +class IntervalOp(TimeUnit): + arg_types = {"unit": False, "expression": True} + + def interval(self): + return Interval( + this=self.expression.copy(), + unit=self.unit.copy() if self.unit else None, + ) + + +# https://www.oracletutorial.com/oracle-basics/oracle-interval/ +# https://trino.io/docs/current/language/types.html#interval-day-to-second +# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html +class IntervalSpan(DataType): + arg_types = {"this": True, "expression": True} + + +class Interval(TimeUnit): + arg_types = {"this": False, "unit": False} + + +class IgnoreNulls(Expression): + pass + + +class RespectNulls(Expression): + pass + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate-function-calls#max_min_clause +class HavingMax(Expression): + arg_types = {"this": True, "expression": True, "max": 
True} + + +# Functions +class Func(Condition): + """ + The base class for all function expressions. + + Attributes: + is_var_len_args (bool): if set to True the last argument defined in arg_types will be + treated as a variable length argument and the argument's value will be stored as a list. + _sql_names (list): the SQL name (1st item in the list) and aliases (subsequent items) for this + function expression. These values are used to map this node to a name during parsing as + well as to provide the function's name during SQL string generation. By default the SQL + name is set to the expression's class name transformed to snake case. + """ + + is_var_len_args = False + + @classmethod + def from_arg_list(cls, args): + if cls.is_var_len_args: + all_arg_keys = list(cls.arg_types) + # If this function supports variable length argument treat the last argument as such. + non_var_len_arg_keys = ( + all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys + ) + num_non_var = len(non_var_len_arg_keys) + + args_dict = { + arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys) + } + args_dict[all_arg_keys[-1]] = args[num_non_var:] + else: + args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)} + + return cls(**args_dict) + + @classmethod + def sql_names(cls): + if cls is Func: + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) + if "_sql_names" not in cls.__dict__: + cls._sql_names = [camel_to_snake_case(cls.__name__)] + return cls._sql_names + + @classmethod + def sql_name(cls): + sql_names = cls.sql_names() + assert sql_names, f"Expected non-empty 'sql_names' for Func: {cls.__name__}." + return sql_names[0] + + @classmethod + def default_parser_mappings(cls): + return {name: cls.from_arg_list for name in cls.sql_names()} + + +class Typeof(Func): + pass + + +class Acos(Func): + pass + + +class Acosh(Func): + pass + + +class Asin(Func): + pass + + +class Asinh(Func): + pass + + +class Atan(Func): + arg_types = {"this": True, "expression": False} + + +class Atanh(Func): + pass + + +class Atan2(Func): + arg_types = {"this": True, "expression": True} + + +class Cot(Func): + pass + + +class Coth(Func): + pass + + +class Cos(Func): + pass + + +class Csc(Func): + pass + + +class Csch(Func): + pass + + +class Sec(Func): + pass + + +class Sech(Func): + pass + + +class Sin(Func): + pass + + +class Sinh(Func): + pass + + +class Tan(Func): + pass + + +class Tanh(Func): + pass + + +class Degrees(Func): + pass + + +class Cosh(Func): + pass + + +class CosineDistance(Func): + arg_types = {"this": True, "expression": True} + + +class DotProduct(Func): + arg_types = {"this": True, "expression": True} + + +class EuclideanDistance(Func): + arg_types = {"this": True, "expression": True} + + +class ManhattanDistance(Func): + arg_types = {"this": True, "expression": True} + + +class JarowinklerSimilarity(Func): + arg_types = {"this": True, "expression": True} + + +class AggFunc(Func): + pass + + +class BitwiseAndAgg(AggFunc): + pass + + +class BitwiseOrAgg(AggFunc): + pass + + +class BitwiseXorAgg(AggFunc): + pass + + +class BoolxorAgg(AggFunc): + pass + + +class BitwiseCount(Func): + pass + + +class BitmapBucketNumber(Func): + pass + + +class BitmapCount(Func): + pass + + +class BitmapBitPosition(Func): + pass + + +class BitmapConstructAgg(AggFunc): + pass + + +class BitmapOrAgg(AggFunc): + pass + + +class ByteLength(Func): + pass + + +class Boolnot(Func): + pass + + +class Booland(Func): + arg_types = {"this": True, "expression": True} + + 
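+# A short sketch of the `Func` classmethods above, using a hypothetical
+# subclass (`MyFunc` is not part of this module; `a`, `b`, `c` stand for
+# Expression instances):
+#
+#     class MyFunc(Func):
+#         arg_types = {"this": True, "expressions": False}
+#         is_var_len_args = True
+#
+#     MyFunc.sql_names()  # ["MY_FUNC"], derived via camel_to_snake_case
+#     MyFunc.from_arg_list([a, b, c]).args
+#     # {"this": a, "expressions": [b, c]} -- trailing args become a list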
+class Boolor(Func): + arg_types = {"this": True, "expression": True} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#bool_for_json +class JSONBool(Func): + pass + + +class ArrayRemove(Func): + arg_types = {"this": True, "expression": True} + + +class ParameterizedAgg(AggFunc): + arg_types = {"this": True, "expressions": True, "params": True} + + +class Abs(Func): + pass + + +class ArgMax(AggFunc): + arg_types = {"this": True, "expression": True, "count": False} + _sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"] + + +class ArgMin(AggFunc): + arg_types = {"this": True, "expression": True, "count": False} + _sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"] + + +class ApproxTopK(AggFunc): + arg_types = {"this": True, "expression": False, "counters": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/approx_top_k_accumulate +# https://spark.apache.org/docs/preview/api/sql/index.html#approx_top_k_accumulate +class ApproxTopKAccumulate(AggFunc): + arg_types = {"this": True, "expression": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/approx_top_k_combine +class ApproxTopKCombine(AggFunc): + arg_types = {"this": True, "expression": False} + + +class ApproxTopKEstimate(Func): + arg_types = {"this": True, "expression": False} + + +class ApproxTopSum(AggFunc): + arg_types = {"this": True, "expression": True, "count": True} + + +class ApproxQuantiles(AggFunc): + arg_types = {"this": True, "expression": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/approx_percentile_combine +class ApproxPercentileCombine(AggFunc): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/minhash +class Minhash(AggFunc): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +# https://docs.snowflake.com/en/sql-reference/functions/minhash_combine +class MinhashCombine(AggFunc): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/approximate_similarity +class ApproximateSimilarity(AggFunc): + _sql_names = ["APPROXIMATE_SIMILARITY", "APPROXIMATE_JACCARD_INDEX"] + + +class FarmFingerprint(Func): + arg_types = {"expressions": True} + is_var_len_args = True + _sql_names = ["FARM_FINGERPRINT", "FARMFINGERPRINT64"] + + +class Flatten(Func): + pass + + +class Float64(Func): + arg_types = {"this": True, "expression": False} + + +# https://spark.apache.org/docs/latest/api/sql/index.html#transform +class Transform(Func): + arg_types = {"this": True, "expression": True} + + +class Translate(Func): + arg_types = {"this": True, "from_": True, "to": True} + + +class Grouping(AggFunc): + arg_types = {"expressions": True} + is_var_len_args = True + + +class GroupingId(AggFunc): + arg_types = {"expressions": True} + is_var_len_args = True + + +class Anonymous(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + @property + def name(self) -> str: + return self.this if isinstance(self.this, str) else self.this.name + + +class AnonymousAggFunc(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators +class CombinedAggFunc(AnonymousAggFunc): + arg_types = {"this": True, "expressions": False} + + +class CombinedParameterizedAgg(ParameterizedAgg): + arg_types = {"this": True, "expressions": True, "params": True} + + +# https://docs.snowflake.com/en/sql-reference/functions/hash_agg +class HashAgg(AggFunc): + arg_types = {"this": True, "expressions": False} + 
is_var_len_args = True + + +# https://docs.snowflake.com/en/sql-reference/functions/hll +# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html +class Hll(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ApproxDistinct(AggFunc): + arg_types = {"this": True, "accuracy": False} + _sql_names = ["APPROX_DISTINCT", "APPROX_COUNT_DISTINCT"] + + +class Apply(Func): + arg_types = {"this": True, "expression": True} + + +class Array(Func): + arg_types = { + "expressions": False, + "bracket_notation": False, + "struct_name_inheritance": False, + } + is_var_len_args = True + + +class Ascii(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/to_array +class ToArray(Func): + pass + + +class ToBoolean(Func): + arg_types = {"this": True, "safe": False} + + +# https://materialize.com/docs/sql/types/list/ +class List(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +# String pad: is_left True -> LPAD, False -> RPAD +class Pad(Func): + arg_types = { + "this": True, + "expression": True, + "fill_pattern": False, + "is_left": True, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/to_char +# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html +class ToChar(Func): + arg_types = { + "this": True, + "format": False, + "nlsparam": False, + "is_numeric": False, + } + + +class ToCodePoints(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/to_decimal +# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html +class ToNumber(Func): + arg_types = { + "this": True, + "format": False, + "nlsparam": False, + "precision": False, + "scale": False, + "safe": False, + "safe_name": False, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/to_double +class ToDouble(Func): + arg_types = { + "this": True, + "format": False, + "safe": False, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/to_decfloat +class ToDecfloat(Func): + arg_types = { + "this": True, + "format": False, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/try_to_decfloat +class TryToDecfloat(Func): + arg_types = { + "this": True, + "format": False, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/to_file +class ToFile(Func): + arg_types = { + "this": True, + "path": False, + "safe": False, + } + + +class CodePointsToBytes(Func): + pass + + +class Columns(Func): + arg_types = {"this": True, "unpack": False} + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax +class Convert(Func): + arg_types = {"this": True, "expression": True, "style": False, "safe": False} + + +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CONVERT.html +class ConvertToCharset(Func): + arg_types = {"this": True, "dest": True, "source": False} + + +class ConvertTimezone(Func): + arg_types = { + "source_tz": False, + "target_tz": True, + "timestamp": True, + "options": False, + } + + +class CodePointsToString(Func): + pass + + +class GenerateSeries(Func): + arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False} + + +# Postgres' GENERATE_SERIES function returns a row set, i.e. it implicitly explodes when it's +# used in a projection, so this expression is a helper that facilitates transpilation to other +# dialects. 
For example, we'd generate UNNEST(GENERATE_SERIES(...)) in DuckDB +class ExplodingGenerateSeries(GenerateSeries): + pass + + +class ArrayAgg(AggFunc): + arg_types = {"this": True, "nulls_excluded": False} + + +class ArrayUniqueAgg(AggFunc): + pass + + +class AIAgg(AggFunc): + arg_types = {"this": True, "expression": True} + _sql_names = ["AI_AGG"] + + +class AISummarizeAgg(AggFunc): + _sql_names = ["AI_SUMMARIZE_AGG"] + + +class AIClassify(Func): + arg_types = {"this": True, "categories": True, "config": False} + _sql_names = ["AI_CLASSIFY"] + + +class ArrayAll(Func): + arg_types = {"this": True, "expression": True} + + +# Represents Python's `any(f(x) for x in array)`, where `array` is `this` and `f` is `expression` +class ArrayAny(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayConcat(Func): + _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"] + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ArrayConcatAgg(AggFunc): + pass + + +class ArrayConstructCompact(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class ArrayContains(Binary, Func): + arg_types = {"this": True, "expression": True, "ensure_variant": False} + _sql_names = ["ARRAY_CONTAINS", "ARRAY_HAS"] + + +class ArrayContainsAll(Binary, Func): + _sql_names = ["ARRAY_CONTAINS_ALL", "ARRAY_HAS_ALL"] + + +class ArrayFilter(Func): + arg_types = {"this": True, "expression": True} + _sql_names = ["FILTER", "ARRAY_FILTER"] + + +class ArrayFirst(Func): + pass + + +class ArrayLast(Func): + pass + + +class ArrayReverse(Func): + pass + + +class ArraySlice(Func): + arg_types = {"this": True, "start": True, "end": False, "step": False} + + +class ArrayToString(Func): + arg_types = {"this": True, "expression": True, "null": False} + _sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"] + + +class ArrayIntersect(Func): + arg_types = {"expressions": True} + is_var_len_args = True + _sql_names = ["ARRAY_INTERSECT", "ARRAY_INTERSECTION"] + + +class StPoint(Func): + arg_types = {"this": True, "expression": True, "null": False} + _sql_names = ["ST_POINT", "ST_MAKEPOINT"] + + +class StDistance(Func): + arg_types = {"this": True, "expression": True, "use_spheroid": False} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/timestamp_functions#string +class String(Func): + arg_types = {"this": True, "zone": False} + + +class StringToArray(Func): + arg_types = {"this": True, "expression": False, "null": False} + _sql_names = ["STRING_TO_ARRAY", "SPLIT_BY_STRING", "STRTOK_TO_ARRAY"] + + +class ArrayOverlaps(Binary, Func): + pass + + +class ArraySize(Func): + arg_types = {"this": True, "expression": False} + _sql_names = ["ARRAY_SIZE", "ARRAY_LENGTH"] + + +class ArraySort(Func): + arg_types = {"this": True, "expression": False} + + +class ArraySum(Func): + arg_types = {"this": True, "expression": False} + + +class ArrayUnionAgg(AggFunc): + pass + + +class Avg(AggFunc): + pass + + +class AnyValue(AggFunc): + pass + + +class Lag(AggFunc): + arg_types = {"this": True, "offset": False, "default": False} + + +class Lead(AggFunc): + arg_types = {"this": True, "offset": False, "default": False} + + +# some dialects have a distinction between first and first_value, usually first is an aggregate func +# and first_value is a window func +class First(AggFunc): + arg_types = {"this": True, "expression": False} + + +class Last(AggFunc): + arg_types = {"this": True, "expression": False} + + +class FirstValue(AggFunc): + pass + + +class LastValue(AggFunc): + pass + + +class 
NthValue(AggFunc): + arg_types = {"this": True, "offset": True} + + +class ObjectAgg(AggFunc): + arg_types = {"this": True, "expression": True} + + +class Case(Func): + arg_types = {"this": False, "ifs": True, "default": False} + + def when( + self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts + ) -> Case: + instance = maybe_copy(self, copy) + instance.append( + "ifs", + If( + this=maybe_parse(condition, copy=copy, **opts), + true=maybe_parse(then, copy=copy, **opts), + ), + ) + return instance + + def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: + instance = maybe_copy(self, copy) + instance.set("default", maybe_parse(condition, copy=copy, **opts)) + return instance + + +class Cast(Func): + arg_types = { + "this": True, + "to": True, + "format": False, + "safe": False, + "action": False, + "default": False, + } + + @property + def name(self) -> str: + return self.this.name + + @property + def to(self) -> DataType: + return self.args["to"] + + @property + def output_name(self) -> str: + return self.name + + def is_type(self, *dtypes: DATA_TYPE) -> bool: + """ + Checks whether this Cast's DataType matches one of the provided data types. Nested types + like arrays or structs will be compared using "structural equivalence" semantics, so e.g. + array<int> != array<float>. + + Args: + dtypes: the data types to compare this Cast's DataType to. + + Returns: + True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType. + """ + return self.to.is_type(*dtypes) + + +class TryCast(Cast): + arg_types = {**Cast.arg_types, "requires_string": False} + + +# https://clickhouse.com/docs/sql-reference/data-types/newjson#reading-json-paths-as-sub-columns +class JSONCast(Cast): + pass + + +class JustifyDays(Func): + pass + + +class JustifyHours(Func): + pass + + +class JustifyInterval(Func): + pass + + +class Try(Func): + pass + + +class CastToStrType(Func): + arg_types = {"this": True, "to": True} + + +class CheckJson(Func): + arg_types = {"this": True} + + +class CheckXml(Func): + arg_types = {"this": True, "disable_auto_convert": False} + + +# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/String-Operators-and-Functions/TRANSLATE/TRANSLATE-Function-Syntax +class TranslateCharacters(Expression): + arg_types = {"this": True, "expression": True, "with_error": False} + + +class Collate(Binary, Func): + pass + + +class Collation(Func): + pass + + +class Ceil(Func): + arg_types = {"this": True, "decimals": False, "to": False} + _sql_names = ["CEIL", "CEILING"] + + +class Coalesce(Func): + arg_types = {"this": True, "expressions": False, "is_nvl": False, "is_null": False} + is_var_len_args = True + _sql_names = ["COALESCE", "IFNULL", "NVL"] + + +class Chr(Func): + arg_types = {"expressions": True, "charset": False} + is_var_len_args = True + _sql_names = ["CHR", "CHAR"] + + +class Concat(Func): + arg_types = {"expressions": True, "safe": False, "coalesce": False} + is_var_len_args = True + + +class ConcatWs(Concat): + _sql_names = ["CONCAT_WS"] + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#contains_substr +class Contains(Func): + arg_types = {"this": True, "expression": True, "json_scope": False} + + +# https://docs.oracle.com/cd/B13789_01/server.101/b10759/operators004.htm#i1035022 +class ConnectByRoot(Func): + pass + + +class Count(AggFunc): + arg_types = {"this": False, "expressions": False, "big_int": False} + is_var_len_args = True + + +class CountIf(AggFunc): + 
_sql_names = ["COUNT_IF", "COUNTIF"] + + +# cube root +class Cbrt(Func): + pass + + +class CurrentAccount(Func): + arg_types = {} + + +class CurrentAccountName(Func): + arg_types = {} + + +class CurrentAvailableRoles(Func): + arg_types = {} + + +class CurrentClient(Func): + arg_types = {} + + +class CurrentIpAddress(Func): + arg_types = {} + + +class CurrentDatabase(Func): + arg_types = {} + + +class CurrentSchemas(Func): + arg_types = {"this": False} + + +class CurrentSecondaryRoles(Func): + arg_types = {} + + +class CurrentSession(Func): + arg_types = {} + + +class CurrentStatement(Func): + arg_types = {} + + +class CurrentVersion(Func): + arg_types = {} + + +class CurrentTransaction(Func): + arg_types = {} + + +class CurrentWarehouse(Func): + arg_types = {} + + +class CurrentDate(Func): + arg_types = {"this": False} + + +class CurrentDatetime(Func): + arg_types = {"this": False} + + +class CurrentTime(Func): + arg_types = {"this": False} + + +# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-CURRENT +# In Postgres, the difference between CURRENT_TIME vs LOCALTIME etc is that the latter does not have tz +class Localtime(Func): + arg_types = {"this": False} + + +class Localtimestamp(Func): + arg_types = {"this": False} + + +class CurrentTimestamp(Func): + arg_types = {"this": False, "sysdate": False} + + +class CurrentTimestampLTZ(Func): + arg_types = {} + + +class CurrentTimezone(Func): + arg_types = {} + + +class CurrentOrganizationName(Func): + arg_types = {} + + +class CurrentSchema(Func): + arg_types = {"this": False} + + +class CurrentUser(Func): + arg_types = {"this": False} + + +class CurrentCatalog(Func): + arg_types = {} + + +class CurrentRegion(Func): + arg_types = {} + + +class CurrentRole(Func): + arg_types = {} + + +class CurrentRoleType(Func): + arg_types = {} + + +class CurrentOrganizationUser(Func): + arg_types = {} + + +class SessionUser(Func): + arg_types = {} + + +class UtcDate(Func): + arg_types = {} + + +class UtcTime(Func): + arg_types = {"this": False} + + +class UtcTimestamp(Func): + arg_types = {"this": False} + + +class DateAdd(Func, IntervalOp): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateBin(Func, IntervalOp): + arg_types = { + "this": True, + "expression": True, + "unit": False, + "zone": False, + "origin": False, + } + + +class DateSub(Func, IntervalOp): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateDiff(Func, TimeUnit): + _sql_names = ["DATEDIFF", "DATE_DIFF"] + arg_types = { + "this": True, + "expression": True, + "unit": False, + "zone": False, + "big_int": False, + "date_part_boundary": False, + } + + +class DateTrunc(Func): + arg_types = {"unit": True, "this": True, "zone": False} + + def __init__(self, **args): + # Across most dialects it's safe to unabbreviate the unit (e.g. 
'Q' -> 'QUARTER') except Oracle + # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html + unabbreviate = args.pop("unabbreviate", True) + + unit = args.get("unit") + if isinstance(unit, TimeUnit.VAR_LIKE) and not ( + isinstance(unit, Column) and len(unit.parts) != 1 + ): + unit_name = unit.name.upper() + if unabbreviate and unit_name in TimeUnit.UNABBREVIATED_UNIT_NAME: + unit_name = TimeUnit.UNABBREVIATED_UNIT_NAME[unit_name] + + args["unit"] = Literal.string(unit_name) + + super().__init__(**args) + + @property + def unit(self) -> Expression: + return self.args["unit"] + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/datetime_functions#datetime +# expression can either be time_expr or time_zone +class Datetime(Func): + arg_types = {"this": True, "expression": False} + + +class DatetimeAdd(Func, IntervalOp): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeSub(Func, IntervalOp): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DateFromUnixDate(Func): + pass + + +class DayOfWeek(Func): + _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"] + + +# https://duckdb.org/docs/sql/functions/datepart.html#part-specifiers-only-usable-as-date-part-specifiers +# ISO day of week function in duckdb is ISODOW +class DayOfWeekIso(Func): + _sql_names = ["DAYOFWEEK_ISO", "ISODOW"] + + +class DayOfMonth(Func): + _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"] + + +class DayOfYear(Func): + _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] + + +class Dayname(Func): + arg_types = {"this": True, "abbreviated": False} + + +class ToDays(Func): + pass + + +class WeekOfYear(Func): + _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] + + +class YearOfWeek(Func): + _sql_names = ["YEAR_OF_WEEK", "YEAROFWEEK"] + + +class YearOfWeekIso(Func): + _sql_names = ["YEAR_OF_WEEK_ISO", "YEAROFWEEKISO"] + + +class MonthsBetween(Func): + arg_types = {"this": True, "expression": True, "roundoff": False} + + +class MakeInterval(Func): + arg_types = { + "year": False, + "month": False, + "week": False, + "day": False, + "hour": False, + "minute": False, + "second": False, + } + + +class LastDay(Func, TimeUnit): + _sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"] + arg_types = {"this": True, "unit": False} + + +class PreviousDay(Func): + arg_types = {"this": True, "expression": True} + + +class LaxBool(Func): + pass + + +class LaxFloat64(Func): + pass + + +class LaxInt64(Func): + pass + + +class LaxString(Func): + pass + + +class Extract(Func): + arg_types = {"this": True, "expression": True} + + +class Exists(Func, SubqueryPredicate): + arg_types = {"this": True, "expression": False} + + +class Elt(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class Timestamp(Func): + arg_types = {"this": False, "zone": False, "with_tz": False, "safe": False} + + +class TimestampAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampDiff(Func, TimeUnit): + _sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"] + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": 
False} + + +class TimeSlice(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": True, "kind": False} + + +class TimeAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DateFromParts(Func): + _sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"] + arg_types = {"year": True, "month": False, "day": False} + + +class TimeFromParts(Func): + _sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"] + arg_types = { + "hour": True, + "min": True, + "sec": True, + "nano": False, + "fractions": False, + "precision": False, + } + + +class DateStrToDate(Func): + pass + + +class DateToDateStr(Func): + pass + + +class DateToDi(Func): + pass + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date +class Date(Func): + arg_types = {"this": False, "zone": False, "expressions": False} + is_var_len_args = True + + +class Day(Func): + pass + + +class Decode(Func): + arg_types = {"this": True, "charset": True, "replace": False} + + +class DecodeCase(Func): + arg_types = {"expressions": True} + is_var_len_args = True + + +class DenseRank(AggFunc): + arg_types = {"expressions": False} + is_var_len_args = True + + +class DiToDate(Func): + pass + + +class Encode(Func): + arg_types = {"this": True, "charset": True} + + +class EqualNull(Func): + arg_types = {"this": True, "expression": True} + + +class Exp(Func): + pass + + +class Factorial(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/flatten +class Explode(Func, UDTF): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +# https://spark.apache.org/docs/latest/api/sql/#inline +class Inline(Func): + pass + + +class ExplodeOuter(Explode): + pass + + +class Posexplode(Explode): + pass + + +class PosexplodeOuter(Posexplode, ExplodeOuter): + pass + + +class PositionalColumn(Expression): + pass + + +class Unnest(Func, UDTF): + arg_types = { + "expressions": True, + "alias": False, + "offset": False, + "explode_array": False, + } + + @property + def selects(self) -> t.List[Expression]: + columns = super().selects + offset = self.args.get("offset") + if offset: + columns = columns + [to_identifier("offset") if offset is True else offset] + return columns + + +class Floor(Func): + arg_types = {"this": True, "decimals": False, "to": False} + + +class FromBase32(Func): + pass + + +class FromBase64(Func): + pass + + +class ToBase32(Func): + pass + + +class ToBase64(Func): + pass + + +class ToBinary(Func): + arg_types = {"this": True, "format": False, "safe": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/base64_decode_binary +class Base64DecodeBinary(Func): + arg_types = {"this": True, "alphabet": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/base64_decode_string +class Base64DecodeString(Func): + arg_types = {"this": True, "alphabet": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/base64_encode +class Base64Encode(Func): + arg_types = {"this": True, "max_line_length": False, "alphabet": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/try_base64_decode_binary +class TryBase64DecodeBinary(Func): + arg_types = {"this": True, "alphabet": False} + + +# 
https://docs.snowflake.com/en/sql-reference/functions/try_base64_decode_string +class TryBase64DecodeString(Func): + arg_types = {"this": True, "alphabet": False} + + +# https://docs.snowflake.com/en/sql-reference/functions/try_hex_decode_binary +class TryHexDecodeBinary(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/try_hex_decode_string +class TryHexDecodeString(Func): + pass + + +# https://trino.io/docs/current/functions/datetime.html#from_iso8601_timestamp +class FromISO8601Timestamp(Func): + _sql_names = ["FROM_ISO8601_TIMESTAMP"] + + +class GapFill(Func): + arg_types = { + "this": True, + "ts_column": True, + "bucket_width": True, + "partitioning_columns": False, + "value_columns": False, + "origin": False, + "ignore_nulls": False, + } + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_date_array +class GenerateDateArray(Func): + arg_types = {"start": True, "end": True, "step": False} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_timestamp_array +class GenerateTimestampArray(Func): + arg_types = {"start": True, "end": True, "step": True} + + +# https://docs.snowflake.com/en/sql-reference/functions/get +class GetExtract(Func): + arg_types = {"this": True, "expression": True} + + +class Getbit(Func): + arg_types = {"this": True, "expression": True} + + +class Greatest(Func): + arg_types = {"this": True, "expressions": False, "ignore_nulls": True} + is_var_len_args = True + + +# Trino's `ON OVERFLOW TRUNCATE [filler_string] {WITH | WITHOUT} COUNT` +# https://trino.io/docs/current/functions/aggregate.html#listagg +class OverflowTruncateBehavior(Expression): + arg_types = {"this": False, "with_count": True} + + +class GroupConcat(AggFunc): + arg_types = {"this": True, "separator": False, "on_overflow": False} + + +class Hex(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/hex_decode_string +class HexDecodeString(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/hex_encode +class HexEncode(Func): + arg_types = {"this": True, "case": False} + + +class Hour(Func): + pass + + +class Minute(Func): + pass + + +class Second(Func): + pass + + +# T-SQL: https://learn.microsoft.com/en-us/sql/t-sql/functions/compress-transact-sql?view=sql-server-ver17 +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/compress +class Compress(Func): + arg_types = {"this": True, "method": False} + + +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/decompress_binary +class DecompressBinary(Func): + arg_types = {"this": True, "method": True} + + +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/decompress_string +class DecompressString(Func): + arg_types = {"this": True, "method": True} + + +class LowerHex(Hex): + pass + + +class And(Connector, Func): + pass + + +class Or(Connector, Func): + pass + + +class Xor(Connector, Func): + arg_types = {"this": False, "expression": False, "expressions": False} + + +class If(Func): + arg_types = {"this": True, "true": True, "false": False} + _sql_names = ["IF", "IIF"] + + +class Nullif(Func): + arg_types = {"this": True, "expression": True} + + +class Initcap(Func): + arg_types = {"this": True, "expression": False} + + +class IsAscii(Func): + pass + + +class IsNan(Func): + _sql_names = ["IS_NAN", "ISNAN"] + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#int64_for_json +class Int64(Func): + pass + + +class IsInf(Func): + 
_sql_names = ["IS_INF", "ISINF"] + + +class IsNullValue(Func): + pass + + +# https://www.postgresql.org/docs/current/functions-json.html +class JSON(Expression): + arg_types = {"this": False, "with_": False, "unique": False} + + +class JSONPath(Expression): + arg_types = {"expressions": True, "escape": False} + + @property + def output_name(self) -> str: + last_segment = self.expressions[-1].this + return last_segment if isinstance(last_segment, str) else "" + + +class JSONPathPart(Expression): + arg_types = {} + + +class JSONPathFilter(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathKey(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathRecursive(JSONPathPart): + arg_types = {"this": False} + + +class JSONPathRoot(JSONPathPart): + pass + + +class JSONPathScript(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathSlice(JSONPathPart): + arg_types = {"start": False, "end": False, "step": False} + + +class JSONPathSelector(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathSubscript(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathUnion(JSONPathPart): + arg_types = {"expressions": True} + + +class JSONPathWildcard(JSONPathPart): + pass + + +class FormatJson(Expression): + pass + + +class Format(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class JSONKeyValue(Expression): + arg_types = {"this": True, "expression": True} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_keys +class JSONKeysAtDepth(Func): + arg_types = {"this": True, "expression": False, "mode": False} + + +class JSONObject(Func): + arg_types = { + "expressions": False, + "null_handling": False, + "unique_keys": False, + "return_type": False, + "encoding": False, + } + + +class JSONObjectAgg(AggFunc): + arg_types = { + "expressions": False, + "null_handling": False, + "unique_keys": False, + "return_type": False, + "encoding": False, + } + + +# https://www.postgresql.org/docs/9.5/functions-aggregate.html +class JSONBObjectAgg(AggFunc): + arg_types = {"this": True, "expression": True} + + +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html +class JSONArray(Func): + arg_types = { + "expressions": False, + "null_handling": False, + "return_type": False, + "strict": False, + } + + +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAYAGG.html +class JSONArrayAgg(AggFunc): + arg_types = { + "this": True, + "order": False, + "null_handling": False, + "return_type": False, + "strict": False, + } + + +class JSONExists(Func): + arg_types = { + "this": True, + "path": True, + "passing": False, + "on_condition": False, + "from_dcolonqmark": False, + } + + +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html +# Note: parsing of JSON column definitions is currently incomplete. 
+class JSONColumnDef(Expression): + arg_types = { + "this": False, + "kind": False, + "path": False, + "nested_schema": False, + "ordinality": False, + } + + +class JSONSchema(Expression): + arg_types = {"expressions": True} + + +class JSONSet(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + _sql_names = ["JSON_SET"] + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_strip_nulls +class JSONStripNulls(Func): + arg_types = { + "this": True, + "expression": False, + "include_arrays": False, + "remove_empty": False, + } + _sql_names = ["JSON_STRIP_NULLS"] + + +# https://dev.mysql.com/doc/refman/8.4/en/json-search-functions.html#function_json-value +class JSONValue(Expression): + arg_types = { + "this": True, + "path": True, + "returning": False, + "on_condition": False, + } + + +class JSONValueArray(Func): + arg_types = {"this": True, "expression": False} + + +class JSONRemove(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + _sql_names = ["JSON_REMOVE"] + + +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html +class JSONTable(Func): + arg_types = { + "this": True, + "schema": True, + "path": False, + "error_handling": False, + "empty_handling": False, + } + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_type +# https://doris.apache.org/docs/sql-manual/sql-functions/scalar-functions/json-functions/json-type#description +class JSONType(Func): + arg_types = {"this": True, "expression": False} + _sql_names = ["JSON_TYPE"] + + +# https://docs.snowflake.com/en/sql-reference/functions/object_insert +class ObjectInsert(Func): + arg_types = { + "this": True, + "key": True, + "value": True, + "update_flag": False, + } + + +class OpenJSONColumnDef(Expression): + arg_types = {"this": True, "kind": True, "path": False, "as_json": False} + + +class OpenJSON(Func): + arg_types = {"this": True, "path": False, "expressions": False} + + +class JSONBContains(Binary, Func): + _sql_names = ["JSONB_CONTAINS"] + + +# https://www.postgresql.org/docs/9.5/functions-json.html +class JSONBContainsAnyTopKeys(Binary, Func): + pass + + +# https://www.postgresql.org/docs/9.5/functions-json.html +class JSONBContainsAllTopKeys(Binary, Func): + pass + + +class JSONBExists(Func): + arg_types = {"this": True, "path": True} + _sql_names = ["JSONB_EXISTS"] + + +# https://www.postgresql.org/docs/9.5/functions-json.html +class JSONBDeleteAtPath(Binary, Func): + pass + + +class JSONExtract(Binary, Func): + arg_types = { + "this": True, + "expression": True, + "only_json_types": False, + "expressions": False, + "variant_extract": False, + "json_query": False, + "option": False, + "quote": False, + "on_condition": False, + "requires_json": False, + } + _sql_names = ["JSON_EXTRACT"] + is_var_len_args = True + + @property + def output_name(self) -> str: + return self.expression.output_name if not self.expressions else "" + + +# https://trino.io/docs/current/functions/json.html#json-query +class JSONExtractQuote(Expression): + arg_types = { + "option": True, + "scalar": False, + } + + +class JSONExtractArray(Func): + arg_types = {"this": True, "expression": False} + _sql_names = ["JSON_EXTRACT_ARRAY"] + + +class JSONExtractScalar(Binary, Func): + arg_types = { + "this": True, + "expression": True, + "only_json_types": False, + "expressions": False, + "json_type": False, + "scalar_only": False, + } + _sql_names = ["JSON_EXTRACT_SCALAR"] + is_var_len_args = True 
+ + @property + def output_name(self) -> str: + return self.expression.output_name + + +class JSONBExtract(Binary, Func): + _sql_names = ["JSONB_EXTRACT"] + + +class JSONBExtractScalar(Binary, Func): + arg_types = {"this": True, "expression": True, "json_type": False} + _sql_names = ["JSONB_EXTRACT_SCALAR"] + + +class JSONFormat(Func): + arg_types = {"this": False, "options": False, "is_json": False, "to_json": False} + _sql_names = ["JSON_FORMAT"] + + +class JSONArrayAppend(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + _sql_names = ["JSON_ARRAY_APPEND"] + + +# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of +class JSONArrayContains(Binary, Predicate, Func): + arg_types = {"this": True, "expression": True, "json_type": False} + _sql_names = ["JSON_ARRAY_CONTAINS"] + + +class JSONArrayInsert(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + _sql_names = ["JSON_ARRAY_INSERT"] + + +class ParseBignumeric(Func): + pass + + +class ParseNumeric(Func): + pass + + +class ParseJSON(Func): + # BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE + # Snowflake also has TRY_PARSE_JSON, which is represented using `safe` + _sql_names = ["PARSE_JSON", "JSON_PARSE"] + arg_types = {"this": True, "expression": False, "safe": False} + + +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/parse_url +# Databricks: https://docs.databricks.com/aws/en/sql/language-manual/functions/parse_url +class ParseUrl(Func): + arg_types = { + "this": True, + "part_to_extract": False, + "key": False, + "permissive": False, + } + + +class ParseIp(Func): + arg_types = {"this": True, "type": True, "permissive": False} + + +class ParseTime(Func): + arg_types = {"this": True, "format": True} + + +class ParseDatetime(Func): + arg_types = {"this": True, "format": False, "zone": False} + + +class Least(Func): + arg_types = {"this": True, "expressions": False, "ignore_nulls": True} + is_var_len_args = True + + +class Left(Func): + arg_types = {"this": True, "expression": True} + + +class Right(Func): + arg_types = {"this": True, "expression": True} + + +class Reverse(Func): + pass + + +class Length(Func): + arg_types = {"this": True, "binary": False, "encoding": False} + _sql_names = ["LENGTH", "LEN", "CHAR_LENGTH", "CHARACTER_LENGTH"] + + +class RtrimmedLength(Func): + pass + + +class BitLength(Func): + pass + + +class Levenshtein(Func): + arg_types = { + "this": True, + "expression": False, + "ins_cost": False, + "del_cost": False, + "sub_cost": False, + "max_dist": False, + } + + +class Ln(Func): + pass + + +class Log(Func): + arg_types = {"this": True, "expression": False} + + +class LogicalOr(AggFunc): + _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"] + + +class LogicalAnd(AggFunc): + _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"] + + +class Lower(Func): + _sql_names = ["LOWER", "LCASE"] + + +class Map(Func): + arg_types = {"keys": False, "values": False} + + @property + def keys(self) -> t.List[Expression]: + keys = self.args.get("keys") + return keys.expressions if keys else [] + + @property + def values(self) -> t.List[Expression]: + values = self.args.get("values") + return values.expressions if values else [] + + +# Represents the MAP {...} syntax in DuckDB - basically convert a struct to a MAP +class ToMap(Func): + pass + + +class MapFromEntries(Func): + pass + + +class MapCat(Func): + arg_types = {"this": True, "expression": True} + + +class MapContainsKey(Func): + arg_types = 
{"this": True, "key": True} + + +class MapDelete(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class MapInsert(Func): + arg_types = {"this": True, "key": False, "value": True, "update_flag": False} + + +class MapKeys(Func): + pass + + +class MapPick(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class MapSize(Func): + pass + + +# https://learn.microsoft.com/en-us/sql/t-sql/language-elements/scope-resolution-operator-transact-sql?view=sql-server-ver16 +class ScopeResolution(Expression): + arg_types = {"this": False, "expression": True} + + +class Slice(Expression): + arg_types = {"this": False, "expression": False, "step": False} + + +class Stream(Expression): + pass + + +class StarMap(Func): + pass + + +class VarMap(Func): + arg_types = {"keys": True, "values": True} + is_var_len_args = True + + @property + def keys(self) -> t.List[Expression]: + return self.args["keys"].expressions + + @property + def values(self) -> t.List[Expression]: + return self.args["values"].expressions + + +# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html +class MatchAgainst(Func): + arg_types = {"this": True, "expressions": True, "modifier": False} + + +class Max(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class MD5(Func): + _sql_names = ["MD5"] + + +# Represents the variant of the MD5 function that returns a binary value +class MD5Digest(Func): + _sql_names = ["MD5_DIGEST"] + + +# https://docs.snowflake.com/en/sql-reference/functions/md5_number_lower64 +class MD5NumberLower64(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/md5_number_upper64 +class MD5NumberUpper64(Func): + pass + + +class Median(AggFunc): + pass + + +class Mode(AggFunc): + arg_types = {"this": False, "deterministic": False} + + +class Min(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class Month(Func): + pass + + +class Monthname(Func): + arg_types = {"this": True, "abbreviated": False} + + +class AddMonths(Func): + arg_types = {"this": True, "expression": True, "preserve_end_of_month": False} + + +class Nvl2(Func): + arg_types = {"this": True, "true": True, "false": False} + + +class Ntile(AggFunc): + arg_types = {"this": False} + + +class Normalize(Func): + arg_types = {"this": True, "form": False, "is_casefold": False} + + +class Normal(Func): + arg_types = {"this": True, "stddev": True, "gen": True} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/net_functions#nethost +class NetHost(Func): + _sql_names = ["NET.HOST"] + + +class Overlay(Func): + arg_types = {"this": True, "expression": True, "from_": True, "for_": False} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function +class Predict(Func): + arg_types = {"this": True, "expression": True, "params_struct": False} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-translate#mltranslate_function +class MLTranslate(Func): + arg_types = {"this": True, "expression": True, "params_struct": True} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-feature-time +class FeaturesAtTime(Func): + arg_types = { + "this": True, + "time": False, + "num_rows": False, + "ignore_feature_nulls": False, + } + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding +class GenerateEmbedding(Func): + 
arg_types = { + "this": True, + "expression": True, + "params_struct": False, + "is_text": False, + } + + +class MLForecast(Func): + arg_types = {"this": True, "expression": False, "params_struct": False} + + +# Represents Snowflake's ! syntax. For example: SELECT model!PREDICT(INPUT_DATA => {*}) +# See: https://docs.snowflake.com/en/guides-overview-ml-functions +class ModelAttribute(Expression): + arg_types = {"this": True, "expression": True} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/search_functions#vector_search +class VectorSearch(Func): + arg_types = { + "this": True, + "column_to_search": True, + "query_table": True, + "query_column_to_search": False, + "top_k": False, + "distance_type": False, + "options": False, + } + + +class Pi(Func): + arg_types = {} + + +class Pow(Binary, Func): + _sql_names = ["POWER", "POW"] + + +class PercentileCont(AggFunc): + arg_types = {"this": True, "expression": False} + + +class PercentileDisc(AggFunc): + arg_types = {"this": True, "expression": False} + + +class PercentRank(AggFunc): + arg_types = {"expressions": False} + is_var_len_args = True + + +class Quantile(AggFunc): + arg_types = {"this": True, "quantile": True} + + +class ApproxQuantile(Quantile): + arg_types = { + "this": True, + "quantile": True, + "accuracy": False, + "weight": False, + "error_tolerance": False, + } + + +# https://docs.snowflake.com/en/sql-reference/functions/approx_percentile_accumulate +class ApproxPercentileAccumulate(AggFunc): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/approx_percentile_estimate +class ApproxPercentileEstimate(Func): + arg_types = {"this": True, "percentile": True} + + +class Quarter(Func): + pass + + +# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions/RANDOM/RANDOM-Function-Syntax +# teradata lower and upper bounds +class Rand(Func): + _sql_names = ["RAND", "RANDOM"] + arg_types = {"this": False, "lower": False, "upper": False} + + +class Randn(Func): + arg_types = {"this": False} + + +class Randstr(Func): + arg_types = {"this": True, "generator": False} + + +class RangeN(Func): + arg_types = {"this": True, "expressions": True, "each": False} + + +class RangeBucket(Func): + arg_types = {"this": True, "expression": True} + + +class Rank(AggFunc): + arg_types = {"expressions": False} + is_var_len_args = True + + +class ReadCSV(Func): + _sql_names = ["READ_CSV"] + is_var_len_args = True + arg_types = {"this": True, "expressions": False} + + +class ReadParquet(Func): + is_var_len_args = True + arg_types = {"expressions": True} + + +class Reduce(Func): + arg_types = {"this": True, "initial": True, "merge": True, "finish": False} + + +class RegexpExtract(Func): + arg_types = { + "this": True, + "expression": True, + "position": False, + "occurrence": False, + "parameters": False, + "group": False, + "null_if_pos_overflow": False, # for transpilation target behavior + } + + +class RegexpExtractAll(Func): + arg_types = { + "this": True, + "expression": True, + "group": False, + "parameters": False, + "position": False, + "occurrence": False, + } + + +class RegexpReplace(Func): + arg_types = { + "this": True, + "expression": True, + "replacement": False, + "position": False, + "occurrence": False, + "modifiers": False, + "single_replace": False, + } + + +class RegexpLike(Binary, Func): + arg_types = {"this": True, "expression": True, "flag": False} + + +class RegexpILike(Binary, Func): + arg_types = 
{"this": True, "expression": True, "flag": False} + + +class RegexpFullMatch(Binary, Func): + arg_types = {"this": True, "expression": True, "options": False} + + +class RegexpInstr(Func): + arg_types = { + "this": True, + "expression": True, + "position": False, + "occurrence": False, + "option": False, + "parameters": False, + "group": False, + } + + +# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html +# limit is the number of times a pattern is applied +class RegexpSplit(Func): + arg_types = {"this": True, "expression": True, "limit": False} + + +class RegexpCount(Func): + arg_types = { + "this": True, + "expression": True, + "position": False, + "parameters": False, + } + + +class RegrValx(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrValy(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrAvgy(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrAvgx(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrCount(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrIntercept(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrR2(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrSxx(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrSxy(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrSyy(AggFunc): + arg_types = {"this": True, "expression": True} + + +class RegrSlope(AggFunc): + arg_types = {"this": True, "expression": True} + + +class Repeat(Func): + arg_types = {"this": True, "times": True} + + +# Some dialects like Snowflake support two argument replace +class Replace(Func): + arg_types = {"this": True, "expression": True, "replacement": False} + + +class Radians(Func): + pass + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16 +# tsql: a non-zero third argument means truncation rather than rounding +class Round(Func): + arg_types = { + "this": True, + "decimals": False, + "truncate": False, + "casts_non_integer_decimals": False, + } + + +class RowNumber(Func): + arg_types = {"this": False} + + +class SafeAdd(Func): + arg_types = {"this": True, "expression": True} + + +class SafeDivide(Func): + arg_types = {"this": True, "expression": True} + + +class SafeMultiply(Func): + arg_types = {"this": True, "expression": True} + + +class SafeNegate(Func): + pass + + +class SafeSubtract(Func): + arg_types = {"this": True, "expression": True} + + +class SafeConvertBytesToString(Func): + pass + + +class SHA(Func): + _sql_names = ["SHA", "SHA1"] + + +class SHA2(Func): + _sql_names = ["SHA2"] + arg_types = {"this": True, "length": False} + + +# Represents the variant of the SHA1 function that returns a binary value +class SHA1Digest(Func): + pass + + +# Represents the variant of the SHA2 function that returns a binary value +class SHA2Digest(Func): + arg_types = {"this": True, "length": False} + + +class Sign(Func): + _sql_names = ["SIGN", "SIGNUM"] + + +class SortArray(Func): + arg_types = {"this": True, "asc": False, "nulls_first": False} + + +class Soundex(Func): + pass + + +# https://docs.snowflake.com/en/sql-reference/functions/soundex_p123 +class SoundexP123(Func): + pass + + +class Split(Func): + arg_types = {"this": True, "expression": True, "limit": False} + + +# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html +# 
https://docs.snowflake.com/en/sql-reference/functions/split_part +# https://docs.snowflake.com/en/sql-reference/functions/strtok +class SplitPart(Func): + arg_types = {"this": True, "delimiter": False, "part_index": False} + + +# Start may be omitted in the case of postgres +# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 +class Substring(Func): + _sql_names = ["SUBSTRING", "SUBSTR"] + arg_types = {"this": True, "start": False, "length": False} + + +class SubstringIndex(Func): + """ + SUBSTRING_INDEX(str, delim, count) + + *count* > 0 → left slice before the *count*-th delimiter + *count* < 0 → right slice after the |count|-th delimiter + """ + + arg_types = {"this": True, "delimiter": True, "count": True} + + +class StandardHash(Func): + arg_types = {"this": True, "expression": False} + + +class StartsWith(Func): + _sql_names = ["STARTS_WITH", "STARTSWITH"] + arg_types = {"this": True, "expression": True} + + +class EndsWith(Func): + _sql_names = ["ENDS_WITH", "ENDSWITH"] + arg_types = {"this": True, "expression": True} + + +class StrPosition(Func): + arg_types = { + "this": True, + "substr": True, + "position": False, + "occurrence": False, + } + + +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/search +# BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/search_functions#search +class Search(Func): + arg_types = { + "this": True, # data_to_search / search_data + "expression": True, # search_query / search_string + "json_scope": False, # BigQuery: JSON_VALUES | JSON_KEYS | JSON_KEYS_AND_VALUES + "analyzer": False, # Both: analyzer / ANALYZER + "analyzer_options": False, # BigQuery: analyzer_options_values + "search_mode": False, # Snowflake: OR | AND + } + + +# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/search_ip +class SearchIp(Func): + arg_types = {"this": True, "expression": True} + + +class StrToDate(Func): + arg_types = {"this": True, "format": False, "safe": False} + + +class StrToTime(Func): + arg_types = { + "this": True, + "format": True, + "zone": False, + "safe": False, + "target_type": False, + } + + +# Spark allows unix_timestamp() +# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html +class StrToUnix(Func): + arg_types = {"this": False, "format": False} + + +# https://prestodb.io/docs/current/functions/string.html +# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map +class StrToMap(Func): + arg_types = { + "this": True, + "pair_delim": False, + "key_value_delim": False, + "duplicate_resolution_callback": False, + } + + +class NumberToStr(Func): + arg_types = {"this": True, "format": True, "culture": False} + + +class FromBase(Func): + arg_types = {"this": True, "expression": True} + + +class Space(Func): + """ + SPACE(n) → string consisting of n blank characters + """ + + pass + + +class Struct(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class StructExtract(Func): + arg_types = {"this": True, "expression": True} + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16 +# https://docs.snowflake.com/en/sql-reference/functions/insert +class Stuff(Func): + _sql_names = ["STUFF", "INSERT"] + arg_types = {"this": True, "start": True, "length": True, "expression": True} + + +class Sum(AggFunc): + pass + + +class Sqrt(Func): + pass + + +class Stddev(AggFunc): + _sql_names = ["STDDEV", "STDEV"] + + +class StddevPop(AggFunc): + pass + + +class 
StddevSamp(AggFunc): + pass + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/time_functions#time +class Time(Func): + arg_types = {"this": False, "zone": False} + + +class TimeToStr(Func): + arg_types = {"this": True, "format": True, "culture": False, "zone": False} + + +class TimeToTimeStr(Func): + pass + + +class TimeToUnix(Func): + pass + + +class TimeStrToDate(Func): + pass + + +class TimeStrToTime(Func): + arg_types = {"this": True, "zone": False} + + +class TimeStrToUnix(Func): + pass + + +class Trim(Func): + arg_types = { + "this": True, + "expression": False, + "position": False, + "collation": False, + } + + +class TsOrDsAdd(Func, TimeUnit): + # return_type is used to correctly cast the arguments of this expression when transpiling it + arg_types = {"this": True, "expression": True, "unit": False, "return_type": False} + + @property + def return_type(self) -> DataType: + return DataType.build(self.args.get("return_type") or DataType.Type.DATE) + + +class TsOrDsDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TsOrDsToDateStr(Func): + pass + + +class TsOrDsToDate(Func): + arg_types = {"this": True, "format": False, "safe": False} + + +class TsOrDsToDatetime(Func): + pass + + +class TsOrDsToTime(Func): + arg_types = {"this": True, "format": False, "safe": False} + + +class TsOrDsToTimestamp(Func): + pass + + +class TsOrDiToDi(Func): + pass + + +class Unhex(Func): + arg_types = {"this": True, "expression": False} + + +class Unicode(Func): + pass + + +class Uniform(Func): + arg_types = {"this": True, "expression": True, "gen": False, "seed": False} + + +# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date +class UnixDate(Func): + pass + + +class UnixToStr(Func): + arg_types = {"this": True, "format": False} + + +# https://prestodb.io/docs/current/functions/datetime.html +# presto has weird zone/hours/minutes +class UnixToTime(Func): + arg_types = { + "this": True, + "scale": False, + "zone": False, + "hours": False, + "minutes": False, + "format": False, + } + + SECONDS = Literal.number(0) + DECIS = Literal.number(1) + CENTIS = Literal.number(2) + MILLIS = Literal.number(3) + DECIMILLIS = Literal.number(4) + CENTIMILLIS = Literal.number(5) + MICROS = Literal.number(6) + DECIMICROS = Literal.number(7) + CENTIMICROS = Literal.number(8) + NANOS = Literal.number(9) + + +class UnixToTimeStr(Func): + pass + + +class UnixSeconds(Func): + pass + + +class UnixMicros(Func): + pass + + +class UnixMillis(Func): + pass + + +class Uuid(Func): + _sql_names = ["UUID", "GEN_RANDOM_UUID", "GENERATE_UUID", "UUID_STRING"] + + arg_types = {"this": False, "name": False, "is_string": False} + + +TIMESTAMP_PARTS = { + "year": False, + "month": False, + "day": False, + "hour": False, + "min": False, + "sec": False, + "nano": False, +} + + +class TimestampFromParts(Func): + _sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"] + arg_types = { + **TIMESTAMP_PARTS, + "zone": False, + "milli": False, + "this": False, + "expression": False, + } + + +class TimestampLtzFromParts(Func): + _sql_names = ["TIMESTAMP_LTZ_FROM_PARTS", "TIMESTAMPLTZFROMPARTS"] + arg_types = TIMESTAMP_PARTS.copy() + + +class TimestampTzFromParts(Func): + _sql_names = ["TIMESTAMP_TZ_FROM_PARTS", "TIMESTAMPTZFROMPARTS"] + arg_types = { + **TIMESTAMP_PARTS, + "zone": False, + } + + +class Upper(Func): + _sql_names = ["UPPER", "UCASE"] + + +class Corr(Binary, AggFunc): + pass + + +# 
https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CUME_DIST.html +class CumeDist(AggFunc): + arg_types = {"expressions": False} + is_var_len_args = True + + +class Variance(AggFunc): + _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] + + +class VariancePop(AggFunc): + _sql_names = ["VARIANCE_POP", "VAR_POP"] + + +class Skewness(AggFunc): + pass + + +class WidthBucket(Func): + arg_types = { + "this": True, + "min_value": True, + "max_value": True, + "num_buckets": True, + } + + +class CovarSamp(Binary, AggFunc): + pass + + +class CovarPop(Binary, AggFunc): + pass + + +class Week(Func): + arg_types = {"this": True, "mode": False} + + +class WeekStart(Expression): + pass + + +class NextDay(Func): + arg_types = {"this": True, "expression": True} + + +class XMLElement(Func): + _sql_names = ["XMLELEMENT"] + arg_types = {"this": True, "expressions": False} + + +class XMLGet(Func): + _sql_names = ["XMLGET"] + arg_types = {"this": True, "expression": True, "instance": False} + + +class XMLTable(Func): + arg_types = { + "this": True, + "namespaces": False, + "passing": False, + "columns": False, + "by_ref": False, + } + + +class XMLNamespace(Expression): + pass + + +# https://learn.microsoft.com/en-us/sql/t-sql/queries/select-for-clause-transact-sql?view=sql-server-ver17#syntax +class XMLKeyValueOption(Expression): + arg_types = {"this": True, "expression": False} + + +class Year(Func): + pass + + +class Zipf(Func): + arg_types = {"this": True, "elementcount": True, "gen": True} + + +class Use(Expression): + arg_types = {"this": False, "expressions": False, "kind": False} + + +class Merge(DML): + arg_types = { + "this": True, + "using": True, + "on": False, + "using_cond": False, + "whens": True, + "with_": False, + "returning": False, + } + + +class When(Expression): + arg_types = {"matched": True, "source": False, "condition": False, "then": True} + + +class Whens(Expression): + """Wraps around one or more WHEN [NOT] MATCHED [...] clauses.""" + + arg_types = {"expressions": True} + + +# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html +# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16 +class NextValueFor(Func): + arg_types = {"this": True, "order": False} + + +# Refers to a trailing semi-colon. This is only used to preserve trailing comments +# select 1; -- my comment +class Semicolon(Expression): + arg_types = {} + + +# BigQuery allows SELECT t FROM t and treats the projection as a struct value. This expression +# type is intended to be constructed by qualify so that we can properly annotate its type later +class TableColumn(Expression): + pass + + +ALL_FUNCTIONS = subclasses(__name__, Func, {AggFunc, Anonymous, Func}) +FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()} + +JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, {JSONPathPart}) + +PERCENTILES = (PercentileCont, PercentileDisc) + + +# Helpers +@t.overload +def maybe_parse( + sql_or_expression: ExpOrStr, + *, + into: t.Type[E], + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... + + +@t.overload +def maybe_parse( + sql_or_expression: str | E, + *, + into: t.Optional[IntoType] = None, + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... 
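+# For example (illustrative; `column` is the builder defined later in this
+# module), both overloads funnel into the implementation below:
+#
+#     maybe_parse("x + 1")            # parses the string into an Add expression
+#     maybe_parse(column("x"))        # passes the expression through as-is
+#     maybe_parse("foo", into=Table)  # parses directly into a Table node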
+ + +def maybe_parse( + sql_or_expression: ExpOrStr, + *, + into: t.Optional[IntoType] = None, + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> Expression: + """Gracefully handle a possible string or expression. + + Example: + >>> maybe_parse("1") + Literal(this=1, is_string=False) + >>> maybe_parse(to_identifier("x")) + Identifier(this=x, quoted=False) + + Args: + sql_or_expression: the SQL code string or an expression + into: the SQLGlot Expression to parse into + dialect: the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + prefix: a string to prefix the sql with before it gets parsed + (automatically includes a space) + copy: whether to copy the expression. + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Expression: the parsed or given expression. + """ + if isinstance(sql_or_expression, Expression): + if copy: + return sql_or_expression.copy() + return sql_or_expression + + if sql_or_expression is None: + raise ParseError("SQL cannot be None") + + import bigframes_vendored.sqlglot + + sql = str(sql_or_expression) + if prefix: + sql = f"{prefix} {sql}" + + return bigframes_vendored.sqlglot.parse_one(sql, read=dialect, into=into, **opts) + + +@t.overload +def maybe_copy(instance: None, copy: bool = True) -> None: + ... + + +@t.overload +def maybe_copy(instance: E, copy: bool = True) -> E: + ... + + +def maybe_copy(instance, copy=True): + return instance.copy() if copy and instance else instance + + +def _to_s( + node: t.Any, verbose: bool = False, level: int = 0, repr_str: bool = False +) -> str: + """Generate a textual representation of an Expression tree""" + indent = "\n" + (" " * (level + 1)) + delim = f",{indent}" + + if isinstance(node, Expression): + args = { + k: v for k, v in node.args.items() if (v is not None and v != []) or verbose + } + + if (node.type or verbose) and not isinstance(node, DataType): + args["_type"] = node.type + if node.comments or verbose: + args["_comments"] = node.comments + + if verbose: + args["_id"] = id(node) + + # Inline leaves for a more compact representation + if node.is_leaf(): + indent = "" + delim = ", " + + repr_str = node.is_string or (isinstance(node, Identifier) and node.quoted) + items = delim.join( + [ + f"{k}={_to_s(v, verbose, level + 1, repr_str=repr_str)}" + for k, v in args.items() + ] + ) + return f"{node.__class__.__name__}({indent}{items})" + + if isinstance(node, list): + items = delim.join(_to_s(i, verbose, level + 1) for i in node) + items = f"{indent}{items}" if items else "" + return f"[{items}]" + + # We use the representation of the string to avoid stripping out important whitespace + if repr_str and isinstance(node, str): + node = repr(node) + + # Indent multiline strings to match the current level + return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines()) + + +def _is_wrong_expression(expression, into): + return isinstance(expression, Expression) and not isinstance(expression, into) + + +def _apply_builder( + expression, + instance, + arg, + copy=True, + prefix=None, + into=None, + dialect=None, + into_arg="this", + **opts, +): + if _is_wrong_expression(expression, into): + expression = into(**{into_arg: expression}) + instance = maybe_copy(instance, copy) + expression = maybe_parse( + sql_or_expression=expression, + prefix=prefix, + into=into, + dialect=dialect, + **opts, + ) + instance.set(arg, 
expression) + return instance + + +def _apply_child_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + properties=None, + **opts, +): + instance = maybe_copy(instance, copy) + parsed = [] + properties = {} if properties is None else properties + + for expression in expressions: + if expression is not None: + if _is_wrong_expression(expression, into): + expression = into(expressions=[expression]) + + expression = maybe_parse( + expression, + into=into, + dialect=dialect, + prefix=prefix, + **opts, + ) + for k, v in expression.args.items(): + if k == "expressions": + parsed.extend(v) + else: + properties[k] = v + + existing = instance.args.get(arg) + if append and existing: + parsed = existing.expressions + parsed + + child = into(expressions=parsed) + for k, v in properties.items(): + child.set(k, v) + instance.set(arg, child) + + return instance + + +def _apply_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + **opts, +): + inst = maybe_copy(instance, copy) + + expressions = [ + maybe_parse( + sql_or_expression=expression, + into=into, + prefix=prefix, + dialect=dialect, + **opts, + ) + for expression in expressions + if expression is not None + ] + + existing_expressions = inst.args.get(arg) + if append and existing_expressions: + expressions = existing_expressions + expressions + + inst.set(arg, expressions) + return inst + + +def _apply_conjunction_builder( + *expressions, + instance, + arg, + into=None, + append=True, + copy=True, + dialect=None, + **opts, +): + expressions = [exp for exp in expressions if exp is not None and exp != ""] + if not expressions: + return instance + + inst = maybe_copy(instance, copy) + + existing = inst.args.get(arg) + if append and existing is not None: + expressions = [existing.this if into else existing] + list(expressions) + + node = and_(*expressions, dialect=dialect, copy=copy, **opts) + + inst.set(arg, into(this=node) if into else node) + return inst + + +def _apply_cte_builder( + instance: E, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + materialized: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + scalar: t.Optional[bool] = None, + **opts, +) -> E: + alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) + as_expression = maybe_parse(as_, dialect=dialect, copy=copy, **opts) + if scalar and not isinstance(as_expression, Subquery): + # scalar CTE must be wrapped in a subquery + as_expression = Subquery(this=as_expression) + cte = CTE( + this=as_expression, + alias=alias_expression, + materialized=materialized, + scalar=scalar, + ) + return _apply_child_list_builder( + cte, + instance=instance, + arg="with_", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive} if recursive else {}, + ) + + +def _combine( + expressions: t.Sequence[t.Optional[ExpOrStr]], + operator: t.Type[Connector], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, +) -> Expression: + conditions = [ + condition(expression, dialect=dialect, copy=copy, **opts) + for expression in expressions + if expression is not None + ] + + this, *rest = conditions + if rest and wrap: + this = _wrap(this, Connector) + for expression in rest: + this = operator( + this=this, expression=_wrap(expression, Connector) if wrap else expression + ) + + return this + + +@t.overload +def 
_wrap(expression: None, kind: t.Type[Expression]) -> None: + ... + + +@t.overload +def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: + ... + + +def _wrap(expression: t.Optional[E], kind: t.Type[Expression]) -> t.Optional[E] | Paren: + return Paren(this=expression) if isinstance(expression, kind) else expression + + +def _apply_set_operation( + *expressions: ExpOrStr, + set_operation: t.Type[S], + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> S: + return reduce( + lambda x, y: set_operation(this=x, expression=y, distinct=distinct, **opts), + (maybe_parse(e, dialect=dialect, copy=copy, **opts) for e in expressions), + ) + + +def union( + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Union: + """ + Initializes a syntax tree for the `UNION` operation. + + Example: + >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + expressions: the SQL code strings, corresponding to the `UNION`'s operands. + If `Expression` instances are passed, they will be used as-is. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression. + opts: other options to use to parse the input expressions. + + Returns: + The new Union instance. + """ + assert len(expressions) >= 2, "At least two expressions are required by `union`." + return _apply_set_operation( + *expressions, + set_operation=Union, + distinct=distinct, + dialect=dialect, + copy=copy, + **opts, + ) + + +def intersect( + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Intersect: + """ + Initializes a syntax tree for the `INTERSECT` operation. + + Example: + >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + expressions: the SQL code strings, corresponding to the `INTERSECT`'s operands. + If `Expression` instances are passed, they will be used as-is. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression. + opts: other options to use to parse the input expressions. + + Returns: + The new Intersect instance. + """ + assert ( + len(expressions) >= 2 + ), "At least two expressions are required by `intersect`." + return _apply_set_operation( + *expressions, + set_operation=Intersect, + distinct=distinct, + dialect=dialect, + copy=copy, + **opts, + ) + + +def except_( + *expressions: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Except: + """ + Initializes a syntax tree for the `EXCEPT` operation. + + Example: + >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + expressions: the SQL code strings, corresponding to the `EXCEPT`'s operands. + If `Expression` instances are passed, they will be used as-is. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression. + opts: other options to use to parse the input expressions. + + Returns: + The new Except instance. + """ + assert len(expressions) >= 2, "At least two expressions are required by `except_`." 
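+ # Illustrative note: like `union` and `intersect` above, the operands are
+ # left-folded by `_apply_set_operation`, so except_(a, b, c) builds
+ # Except(Except(a, b), c).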
+ return _apply_set_operation( + *expressions, + set_operation=Except, + distinct=distinct, + dialect=dialect, + copy=copy, + **opts, + ) + + +def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select: + """ + Initializes a syntax tree from one or multiple SELECT expressions. + + Example: + >>> select("col1", "col2").from_("tbl").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expressions: the SQL code string to parse as the expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().select(*expressions, dialect=dialect, **opts) + + +def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select: + """ + Initializes a syntax tree from a FROM expression. + + Example: + >>> from_("tbl").select("col1", "col2").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expression: the SQL code string to parse as the FROM expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().from_(expression, dialect=dialect, **opts) + + +def update( + table: str | Table, + properties: t.Optional[dict] = None, + where: t.Optional[ExpOrStr] = None, + from_: t.Optional[ExpOrStr] = None, + with_: t.Optional[t.Dict[str, ExpOrStr]] = None, + dialect: DialectType = None, + **opts, +) -> Update: + """ + Creates an update statement. + + Example: + >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz_cte", where="baz_cte.id > 1 and my_table.id = baz_cte.id", with_={"baz_cte": "SELECT id FROM foo"}).sql() + "WITH baz_cte AS (SELECT id FROM foo) UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz_cte WHERE baz_cte.id > 1 AND my_table.id = baz_cte.id" + + Args: + properties: dictionary of properties to SET which are + auto converted to sql objects eg None -> NULL + where: sql conditional parsed into a WHERE statement + from_: sql statement parsed into a FROM statement + with_: dictionary of CTE aliases / select statements to include in a WITH clause. + dialect: the dialect used to parse the input expressions. + **opts: other options to use to parse the input expressions. + + Returns: + Update: the syntax tree for the UPDATE statement. 
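+
+ A smaller call without a CTE (illustrative):
+
+ >>> update("tbl", {"x": None}, where="y = 2").sql()
+ 'UPDATE tbl SET x = NULL WHERE y = 2'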
+ """ + update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect)) + if properties: + update_expr.set( + "expressions", + [ + EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) + for k, v in properties.items() + ], + ) + if from_: + update_expr.set( + "from_", + maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), + ) + if isinstance(where, Condition): + where = Where(this=where) + if where: + update_expr.set( + "where", + maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), + ) + if with_: + cte_list = [ + alias_( + CTE(this=maybe_parse(qry, dialect=dialect, **opts)), alias, table=True + ) + for alias, qry in with_.items() + ] + update_expr.set( + "with_", + With(expressions=cte_list), + ) + return update_expr + + +def delete( + table: ExpOrStr, + where: t.Optional[ExpOrStr] = None, + returning: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + **opts, +) -> Delete: + """ + Builds a delete statement. + + Example: + >>> delete("my_table", where="id > 1").sql() + 'DELETE FROM my_table WHERE id > 1' + + Args: + where: sql conditional parsed into a WHERE statement + returning: sql conditional parsed into a RETURNING statement + dialect: the dialect used to parse the input expressions. + **opts: other options to use to parse the input expressions. + + Returns: + Delete: the syntax tree for the DELETE statement. + """ + delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts) + if where: + delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) + if returning: + delete_expr = delete_expr.returning( + returning, dialect=dialect, copy=False, **opts + ) + return delete_expr + + +def insert( + expression: ExpOrStr, + into: ExpOrStr, + columns: t.Optional[t.Sequence[str | Identifier]] = None, + overwrite: t.Optional[bool] = None, + returning: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Insert: + """ + Builds an INSERT statement. + + Example: + >>> insert("VALUES (1, 2, 3)", "tbl").sql() + 'INSERT INTO tbl VALUES (1, 2, 3)' + + Args: + expression: the sql string or expression of the INSERT statement + into: the tbl to insert data to. + columns: optionally the table's column names. + overwrite: whether to INSERT OVERWRITE or not. + returning: sql conditional parsed into a RETURNING statement + dialect: the dialect used to parse the input expressions. + copy: whether to copy the expression. + **opts: other options to use to parse the input expressions. + + Returns: + Insert: the syntax tree for the INSERT statement. + """ + expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts) + this: Table | Schema = maybe_parse( + into, into=Table, dialect=dialect, copy=copy, **opts + ) + + if columns: + this = Schema( + this=this, expressions=[to_identifier(c, copy=copy) for c in columns] + ) + + insert = Insert(this=this, expression=expr, overwrite=overwrite) + + if returning: + insert = insert.returning(returning, dialect=dialect, copy=False, **opts) + + return insert + + +def merge( + *when_exprs: ExpOrStr, + into: ExpOrStr, + using: ExpOrStr, + on: ExpOrStr, + returning: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Merge: + """ + Builds a MERGE statement. + + Example: + >>> merge("WHEN MATCHED THEN UPDATE SET col1 = source_table.col1", + ... "WHEN NOT MATCHED THEN INSERT (col1) VALUES (source_table.col1)", + ... into="my_table", + ... using="source_table", + ... 
on="my_table.id = source_table.id").sql() + 'MERGE INTO my_table USING source_table ON my_table.id = source_table.id WHEN MATCHED THEN UPDATE SET col1 = source_table.col1 WHEN NOT MATCHED THEN INSERT (col1) VALUES (source_table.col1)' + + Args: + *when_exprs: The WHEN clauses specifying actions for matched and unmatched rows. + into: The target table to merge data into. + using: The source table to merge data from. + on: The join condition for the merge. + returning: The columns to return from the merge. + dialect: The dialect used to parse the input expressions. + copy: Whether to copy the expression. + **opts: Other options to use to parse the input expressions. + + Returns: + Merge: The syntax tree for the MERGE statement. + """ + expressions: t.List[Expression] = [] + for when_expr in when_exprs: + expression = maybe_parse( + when_expr, dialect=dialect, copy=copy, into=Whens, **opts + ) + expressions.extend( + [expression] if isinstance(expression, When) else expression.expressions + ) + + merge = Merge( + this=maybe_parse(into, dialect=dialect, copy=copy, **opts), + using=maybe_parse(using, dialect=dialect, copy=copy, **opts), + on=maybe_parse(on, dialect=dialect, copy=copy, **opts), + whens=Whens(expressions=expressions), + ) + if returning: + merge = merge.returning(returning, dialect=dialect, copy=False, **opts) + + if isinstance(using_clause := merge.args.get("using"), Alias): + using_clause.replace( + alias_(using_clause.this, using_clause.args["alias"], table=True) + ) + + return merge + + +def condition( + expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: + """ + Initialize a logical condition expression. + + Example: + >>> condition("x=1").sql() + 'x = 1' + + This is helpful for composing larger logical syntax trees: + >>> where = condition("x=1") + >>> where = where.and_("y=1") + >>> Select().from_("tbl").select("*").where(where).sql() + 'SELECT * FROM tbl WHERE x = 1 AND y = 1' + + Args: + *expression: the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + copy: Whether to copy `expression` (only applies to expressions). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + The new Condition instance + """ + return maybe_parse( + expression, + into=Condition, + dialect=dialect, + copy=copy, + **opts, + ) + + +def and_( + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, +) -> Condition: + """ + Combine multiple conditions with an AND logical operator. + + Example: + >>> and_("x=1", and_("y=1", "z=1")).sql() + 'x = 1 AND (y = 1 AND z = 1)' + + Args: + *expressions: the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy `expressions` (only applies to Expressions). + wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid + precedence issues, but can be turned off when the produced AST is too deep and + causes recursion-related issues. + **opts: other options to use to parse the input expressions. 
+ + Returns: + The new condition + """ + return t.cast( + Condition, _combine(expressions, And, dialect, copy=copy, wrap=wrap, **opts) + ) + + +def or_( + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, +) -> Condition: + """ + Combine multiple conditions with an OR logical operator. + + Example: + >>> or_("x=1", or_("y=1", "z=1")).sql() + 'x = 1 OR (y = 1 OR z = 1)' + + Args: + *expressions: the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy `expressions` (only applies to Expressions). + wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid + precedence issues, but can be turned off when the produced AST is too deep and + causes recursion-related issues. + **opts: other options to use to parse the input expressions. + + Returns: + The new condition + """ + return t.cast( + Condition, _combine(expressions, Or, dialect, copy=copy, wrap=wrap, **opts) + ) + + +def xor( + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + wrap: bool = True, + **opts, +) -> Condition: + """ + Combine multiple conditions with an XOR logical operator. + + Example: + >>> xor("x=1", xor("y=1", "z=1")).sql() + 'x = 1 XOR (y = 1 XOR z = 1)' + + Args: + *expressions: the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy `expressions` (only applies to Expressions). + wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid + precedence issues, but can be turned off when the produced AST is too deep and + causes recursion-related issues. + **opts: other options to use to parse the input expressions. + + Returns: + The new condition + """ + return t.cast( + Condition, _combine(expressions, Xor, dialect, copy=copy, wrap=wrap, **opts) + ) + + +def not_( + expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts +) -> Not: + """ + Wrap a condition with a NOT operator. + + Example: + >>> not_("this_suit='black'").sql() + "NOT this_suit = 'black'" + + Args: + expression: the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression or not. + **opts: other options to use to parse the input expressions. + + Returns: + The new condition. + """ + this = condition( + expression, + dialect=dialect, + copy=copy, + **opts, + ) + return Not(this=_wrap(this, Connector)) + + +def paren(expression: ExpOrStr, copy: bool = True) -> Paren: + """ + Wrap an expression in parentheses. + + Example: + >>> paren("5 + 3").sql() + '(5 + 3)' + + Args: + expression: the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + copy: whether to copy the expression or not. + + Returns: + The wrapped expression. + """ + return Paren(this=maybe_parse(expression, copy=copy)) + + +SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$") + + +@t.overload +def to_identifier( + name: None, quoted: t.Optional[bool] = None, copy: bool = True +) -> None: + ... + + +@t.overload +def to_identifier( + name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True +) -> Identifier: + ... + + +def to_identifier(name, quoted=None, copy=True): + """Builds an identifier. 
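+
+ Example (illustrative; default-dialect quoting assumed):
+ >>> to_identifier("my col").sql()
+ '"my col"'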
+ + Args: + name: The name to turn into an identifier. + quoted: Whether to force quote the identifier. + copy: Whether to copy name if it's an Identifier. + + Returns: + The identifier ast node. + """ + + if name is None: + return None + + if isinstance(name, Identifier): + identifier = maybe_copy(name, copy) + elif isinstance(name, str): + identifier = Identifier( + this=name, + quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted, + ) + else: + raise ValueError( + f"Name needs to be a string or an Identifier, got: {name.__class__}" + ) + return identifier + + +def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier: + """ + Parses a given string into an identifier. + + Args: + name: The name to parse into an identifier. + dialect: The dialect to parse against. + + Returns: + The identifier ast node. + """ + try: + expression = maybe_parse(name, dialect=dialect, into=Identifier) + except (ParseError, TokenError): + expression = to_identifier(name) + + return expression + + +INTERVAL_STRING_RE = re.compile(r"\s*(-?[0-9]+(?:\.[0-9]+)?)\s*([a-zA-Z]+)\s*") + +# Matches day-time interval strings that contain +# - A number of days (possibly negative or with decimals) +# - At least one space +# - Portions of a time-like signature, potentially negative +# - Standard format [-]h+:m+:s+[.f+] +# - Just minutes/seconds/frac seconds [-]m+:s+.f+ +# - Just hours, minutes, maybe colon [-]h+:m+[:] +# - Just hours, maybe colon [-]h+[:] +# - Just colon : +INTERVAL_DAY_TIME_RE = re.compile( + r"\s*-?\s*\d+(?:\.\d+)?\s+(?:-?(?:\d+:)?\d+:\d+(?:\.\d+)?|-?(?:\d+:){1,2}|:)\s*" +) + + +def to_interval(interval: str | Literal) -> Interval: + """Builds an interval expression from a string like '1 day' or '5 months'.""" + if isinstance(interval, Literal): + if not interval.is_string: + raise ValueError("Invalid interval string.") + + interval = interval.this + + interval = maybe_parse(f"INTERVAL {interval}") + assert isinstance(interval, Interval) + return interval + + +def to_table( + sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs +) -> Table: + """ + Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. + If a table is passed in then that table is returned. + + Args: + sql_path: a `[catalog].[schema].[table]` string. + dialect: the source dialect according to which the table name will be parsed. + copy: Whether to copy a table if it is passed in. + kwargs: the kwargs to instantiate the resulting `Table` expression with. + + Returns: + A table expression. + """ + if isinstance(sql_path, Table): + return maybe_copy(sql_path, copy=copy) + + try: + table = maybe_parse(sql_path, into=Table, dialect=dialect) + except ParseError: + catalog, db, this = split_num_words(sql_path, ".", 3) + + if not this: + raise + + table = table_(this, db=db, catalog=catalog) + + for k, v in kwargs.items(): + table.set(k, v) + + return table + + +def to_column( + sql_path: str | Column, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **kwargs, +) -> Column: + """ + Create a column from a `[table].[column]` sql path. Table is optional. + If a column is passed in then that column is returned. + + Args: + sql_path: a `[table].[column]` string. + quoted: Whether or not to force quote identifiers. + dialect: the source dialect according to which the column name will be parsed. + copy: Whether to copy a column if it is passed in. 
+ kwargs: the kwargs to instantiate the resulting `Column` expression with. + + Returns: + A column expression. + """ + if isinstance(sql_path, Column): + return maybe_copy(sql_path, copy=copy) + + try: + col = maybe_parse(sql_path, into=Column, dialect=dialect) + except ParseError: + return column(*reversed(sql_path.split(".")), quoted=quoted, **kwargs) + + for k, v in kwargs.items(): + col.set(k, v) + + if quoted: + for i in col.find_all(Identifier): + i.set("quoted", True) + + return col + + +def alias_( + expression: ExpOrStr, + alias: t.Optional[str | Identifier], + table: bool | t.Sequence[str | Identifier] = False, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +): + """Create an Alias expression. + + Example: + >>> alias_('foo', 'bar').sql() + 'foo AS bar' + + >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql() + '(SELECT 1, 2) AS bar(a, b)' + + Args: + expression: the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + alias: the alias name to use. If the name has + special characters it is quoted. + table: Whether to create a table alias, can also be a list of columns. + quoted: whether to quote the alias + dialect: the dialect used to parse the input expression. + copy: Whether to copy the expression. + **opts: other options to use to parse the input expressions. + + Returns: + Alias: the aliased expression + """ + exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts) + alias = to_identifier(alias, quoted=quoted) + + if table: + table_alias = TableAlias(this=alias) + exp.set("alias", table_alias) + + if not isinstance(table, bool): + for column in table: + table_alias.append("columns", to_identifier(column, quoted=quoted)) + + return exp + + # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in + # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node + # for the complete Window expression. + # + # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls + + if "alias" in exp.arg_types and not isinstance(exp, Window): + exp.set("alias", alias) + return exp + return Alias(this=exp, alias=alias) + + +def subquery( + expression: ExpOrStr, + alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + **opts, +) -> Select: + """ + Build a subquery expression that's selected from. + + Example: + >>> subquery('select x from tbl', 'bar').select('x').sql() + 'SELECT x FROM (SELECT x FROM tbl) AS bar' + + Args: + expression: the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + alias: the alias name to use. + dialect: the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + A new Select instance with the subquery expression included. 
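+
+ Another illustrative call:
+
+ >>> subquery('select x from tbl', 'bar').select('count(*)').sql()
+ 'SELECT COUNT(*) FROM (SELECT x FROM tbl) AS bar'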
+ """ + + expression = maybe_parse(expression, dialect=dialect, **opts).subquery( + alias, **opts + ) + return Select().from_(expression, dialect=dialect, **opts) + + +@t.overload +def column( + col: str | Identifier, + table: t.Optional[str | Identifier] = None, + db: t.Optional[str | Identifier] = None, + catalog: t.Optional[str | Identifier] = None, + *, + fields: t.Collection[t.Union[str, Identifier]], + quoted: t.Optional[bool] = None, + copy: bool = True, +) -> Dot: + pass + + +@t.overload +def column( + col: str | Identifier | Star, + table: t.Optional[str | Identifier] = None, + db: t.Optional[str | Identifier] = None, + catalog: t.Optional[str | Identifier] = None, + *, + fields: Lit[None] = None, + quoted: t.Optional[bool] = None, + copy: bool = True, +) -> Column: + pass + + +def column( + col, + table=None, + db=None, + catalog=None, + *, + fields=None, + quoted=None, + copy=True, +): + """ + Build a Column. + + Args: + col: Column name. + table: Table name. + db: Database name. + catalog: Catalog name. + fields: Additional fields using dots. + quoted: Whether to force quotes on the column's identifiers. + copy: Whether to copy identifiers if passed in. + + Returns: + The new Column instance. + """ + if not isinstance(col, Star): + col = to_identifier(col, quoted=quoted, copy=copy) + + this = Column( + this=col, + table=to_identifier(table, quoted=quoted, copy=copy), + db=to_identifier(db, quoted=quoted, copy=copy), + catalog=to_identifier(catalog, quoted=quoted, copy=copy), + ) + + if fields: + this = Dot.build( + ( + this, + *(to_identifier(field, quoted=quoted, copy=copy) for field in fields), + ) + ) + return this + + +def cast( + expression: ExpOrStr, + to: DATA_TYPE, + copy: bool = True, + dialect: DialectType = None, + **opts, +) -> Cast: + """Cast an expression to a data type. + + Example: + >>> cast('x + 1', 'int').sql() + 'CAST(x + 1 AS INT)' + + Args: + expression: The expression to cast. + to: The datatype to cast to. + copy: Whether to copy the supplied expressions. + dialect: The target dialect. This is used to prevent a re-cast in the following scenario: + - The expression to be cast is already a exp.Cast expression + - The existing cast is to a type that is logically equivalent to new type + + For example, if :expression='CAST(x as DATETIME)' and :to=Type.TIMESTAMP, + but in the target dialect DATETIME is mapped to TIMESTAMP, then we will NOT return `CAST(x (as DATETIME) as TIMESTAMP)` + and instead just return the original expression `CAST(x as DATETIME)`. + + This is to prevent it being output as a double cast `CAST(x (as TIMESTAMP) as TIMESTAMP)` once the DATETIME -> TIMESTAMP + mapping is applied in the target dialect generator. + + Returns: + The new Cast instance. 
+ """ + expr = maybe_parse(expression, copy=copy, dialect=dialect, **opts) + data_type = DataType.build(to, copy=copy, dialect=dialect, **opts) + + # dont re-cast if the expression is already a cast to the correct type + if isinstance(expr, Cast): + from bigframes_vendored.sqlglot.dialects.dialect import Dialect + + target_dialect = Dialect.get_or_raise(dialect) + type_mapping = target_dialect.generator_class.TYPE_MAPPING + + existing_cast_type: DataType.Type = expr.to.this + new_cast_type: DataType.Type = data_type.this + types_are_equivalent = type_mapping.get( + existing_cast_type, existing_cast_type.value + ) == type_mapping.get(new_cast_type, new_cast_type.value) + + if expr.is_type(data_type) or types_are_equivalent: + return expr + + expr = Cast(this=expr, to=data_type) + expr.type = data_type + + return expr + + +def table_( + table: Identifier | str, + db: t.Optional[Identifier | str] = None, + catalog: t.Optional[Identifier | str] = None, + quoted: t.Optional[bool] = None, + alias: t.Optional[Identifier | str] = None, +) -> Table: + """Build a Table. + + Args: + table: Table name. + db: Database name. + catalog: Catalog name. + quote: Whether to force quotes on the table's identifiers. + alias: Table's alias. + + Returns: + The new Table instance. + """ + return Table( + this=to_identifier(table, quoted=quoted) if table else None, + db=to_identifier(db, quoted=quoted) if db else None, + catalog=to_identifier(catalog, quoted=quoted) if catalog else None, + alias=TableAlias(this=to_identifier(alias)) if alias else None, + ) + + +def values( + values: t.Iterable[t.Tuple[t.Any, ...]], + alias: t.Optional[str] = None, + columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None, +) -> Values: + """Build VALUES statement. + + Example: + >>> values([(1, '2')]).sql() + "VALUES (1, '2')" + + Args: + values: values statements that will be converted to SQL + alias: optional alias + columns: Optional list of ordered column names or ordered dictionary of column names to types. + If either are provided then an alias is also required. + + Returns: + Values: the Values expression object + """ + if columns and not alias: + raise ValueError("Alias is required when providing columns") + + return Values( + expressions=[convert(tup) for tup in values], + alias=( + TableAlias( + this=to_identifier(alias), columns=[to_identifier(x) for x in columns] + ) + if columns + else (TableAlias(this=to_identifier(alias)) if alias else None) + ), + ) + + +def var(name: t.Optional[ExpOrStr]) -> Var: + """Build a SQL variable. + + Example: + >>> repr(var('x')) + 'Var(this=x)' + + >>> repr(var(column('x', table='y'))) + 'Var(this=x)' + + Args: + name: The name of the var or an expression who's name will become the var. + + Returns: + The new variable node. + """ + if not name: + raise ValueError("Cannot convert empty name into var.") + + if isinstance(name, Expression): + name = name.name + return Var(this=name) + + +def rename_table( + old_name: str | Table, + new_name: str | Table, + dialect: DialectType = None, +) -> Alter: + """Build ALTER TABLE... RENAME... expression + + Args: + old_name: The old name of the table + new_name: The new name of the table + dialect: The dialect to parse the table. 
+ + Returns: + Alter table expression + """ + old_table = to_table(old_name, dialect=dialect) + new_table = to_table(new_name, dialect=dialect) + return Alter( + this=old_table, + kind="TABLE", + actions=[ + AlterRename(this=new_table), + ], + ) + + +def rename_column( + table_name: str | Table, + old_column_name: str | Column, + new_column_name: str | Column, + exists: t.Optional[bool] = None, + dialect: DialectType = None, +) -> Alter: + """Build ALTER TABLE... RENAME COLUMN... expression + + Args: + table_name: Name of the table + old_column_name: The old name of the column + new_column_name: The new name of the column + exists: Whether to add the `IF EXISTS` clause + dialect: The dialect to parse the table/column. + + Returns: + Alter table expression + """ + table = to_table(table_name, dialect=dialect) + old_column = to_column(old_column_name, dialect=dialect) + new_column = to_column(new_column_name, dialect=dialect) + return Alter( + this=table, + kind="TABLE", + actions=[ + RenameColumn(this=old_column, to=new_column, exists=exists), + ], + ) + + +def convert(value: t.Any, copy: bool = False) -> Expression: + """Convert a python value into an expression object. + + Raises an error if a conversion is not possible. + + Args: + value: A python object. + copy: Whether to copy `value` (only applies to Expressions and collections). + + Returns: + The equivalent expression object. + """ + if isinstance(value, Expression): + return maybe_copy(value, copy) + if isinstance(value, str): + return Literal.string(value) + if isinstance(value, bool): + return Boolean(this=value) + if value is None or (isinstance(value, float) and math.isnan(value)): + return null() + if isinstance(value, numbers.Number): + return Literal.number(value) + if isinstance(value, bytes): + return HexString(this=value.hex()) + if isinstance(value, datetime.datetime): + datetime_literal = Literal.string(value.isoformat(sep=" ")) + + tz = None + if value.tzinfo: + # this works for zoneinfo.ZoneInfo, pytz.timezone and datetime.timezone.utc to return IANA timezone names like "America/Los_Angeles" + # instead of abbreviations like "PDT". 
This is for consistency with other timezone handling functions in SQLGlot + tz = Literal.string(str(value.tzinfo)) + + return TimeStrToTime(this=datetime_literal, zone=tz) + if isinstance(value, datetime.date): + date_literal = Literal.string(value.strftime("%Y-%m-%d")) + return DateStrToDate(this=date_literal) + if isinstance(value, datetime.time): + time_literal = Literal.string(value.isoformat()) + return TsOrDsToTime(this=time_literal) + if isinstance(value, tuple): + if hasattr(value, "_fields"): + return Struct( + expressions=[ + PropertyEQ( + this=to_identifier(k), + expression=convert(getattr(value, k), copy=copy), + ) + for k in value._fields + ] + ) + return Tuple(expressions=[convert(v, copy=copy) for v in value]) + if isinstance(value, list): + return Array(expressions=[convert(v, copy=copy) for v in value]) + if isinstance(value, dict): + return Map( + keys=Array(expressions=[convert(k, copy=copy) for k in value]), + values=Array(expressions=[convert(v, copy=copy) for v in value.values()]), + ) + if hasattr(value, "__dict__"): + return Struct( + expressions=[ + PropertyEQ(this=to_identifier(k), expression=convert(v, copy=copy)) + for k, v in value.__dict__.items() + ] + ) + raise ValueError(f"Cannot convert {value}") + + +def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None: + """ + Replace children of an expression with the result of a lambda fun(child) -> exp. + """ + for k, v in tuple(expression.args.items()): + is_list_arg = type(v) is list + + child_nodes = v if is_list_arg else [v] + new_child_nodes = [] + + for cn in child_nodes: + if isinstance(cn, Expression): + for child_node in ensure_collection(fun(cn, *args, **kwargs)): + new_child_nodes.append(child_node) + else: + new_child_nodes.append(cn) + + expression.set( + k, new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) + ) + + +def replace_tree( + expression: Expression, + fun: t.Callable, + prune: t.Optional[t.Callable[[Expression], bool]] = None, +) -> Expression: + """ + Replace an entire tree with the result of function calls on each node. + + This will be traversed in reverse dfs, so leaves first. + If new nodes are created as a result of function calls, they will also be traversed. + """ + stack = list(expression.dfs(prune=prune)) + + while stack: + node = stack.pop() + new_node = fun(node) + + if new_node is not node: + node.replace(new_node) + + if isinstance(new_node, Expression): + stack.append(new_node) + + return new_node + + +def find_tables(expression: Expression) -> t.Set[Table]: + """ + Find all tables referenced in a query. + + Args: + expressions: The query to find the tables in. + + Returns: + A set of all the tables. + """ + from bigframes_vendored.sqlglot.optimizer.scope import traverse_scope + + return { + table + for scope in traverse_scope(expression) + for table in scope.tables + if table.name and table.name not in scope.cte_sources + } + + +def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: + """ + Return all table names referenced through columns in an expression. + + Example: + >>> import sqlglot + >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))) + ['a', 'c'] + + Args: + expression: expression to find table names. + exclude: a table name to exclude + + Returns: + A list of unique names. 
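+
+ With `exclude` (illustrative, reusing the import above):
+ >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d"), exclude="a"))
+ ['c']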
+ """ + return { + table + for table in (column.table for column in expression.find_all(Column)) + if table and table != exclude + } + + +def table_name( + table: Table | str, dialect: DialectType = None, identify: bool = False +) -> str: + """Get the full name of a table as a string. + + Args: + table: Table expression node or string. + dialect: The dialect to generate the table name for. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True: Always quote. + + Examples: + >>> from sqlglot import exp, parse_one + >>> table_name(parse_one("select * from a.b.c").find(exp.Table)) + 'a.b.c' + + Returns: + The table name. + """ + + table = maybe_parse(table, into=Table, dialect=dialect) + + if not table: + raise ValueError(f"Cannot parse {table}") + + return ".".join( + ( + part.sql(dialect=dialect, identify=True, copy=False, comments=False) + if identify or not SAFE_IDENTIFIER_RE.match(part.name) + else part.name + ) + for part in table.parts + ) + + +def normalize_table_name( + table: str | Table, dialect: DialectType = None, copy: bool = True +) -> str: + """Returns a case normalized table name without quotes. + + Args: + table: the table to normalize + dialect: the dialect to use for normalization rules + copy: whether to copy the expression. + + Examples: + >>> normalize_table_name("`A-B`.c", dialect="bigquery") + 'A-B.c' + """ + from bigframes_vendored.sqlglot.optimizer.normalize_identifiers import ( + normalize_identifiers, + ) + + return ".".join( + p.name + for p in normalize_identifiers( + to_table(table, dialect=dialect, copy=copy), dialect=dialect + ).parts + ) + + +def replace_tables( + expression: E, + mapping: t.Dict[str, str], + dialect: DialectType = None, + copy: bool = True, +) -> E: + """Replace all tables in expression according to the mapping. + + Args: + expression: expression node to be transformed and replaced. + mapping: mapping of table names. + dialect: the dialect of the mapping table + copy: whether to copy the expression. + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() + 'SELECT * FROM c /* a.b */' + + Returns: + The mapped expression. + """ + + mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()} + + def _replace_tables(node: Expression) -> Expression: + if isinstance(node, Table) and node.meta.get("replace") is not False: + original = normalize_table_name(node, dialect=dialect) + new_name = mapping.get(original) + + if new_name: + table = to_table( + new_name, + **{k: v for k, v in node.args.items() if k not in TABLE_PARTS}, + dialect=dialect, + ) + table.add_comments([original]) + return table + return node + + return expression.transform(_replace_tables, copy=copy) # type: ignore + + +def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: + """Replace placeholders in an expression. + + Args: + expression: expression node to be transformed and replaced. + args: positional names that will substitute unnamed placeholders in the given order. + kwargs: keyword arguments that will substitute named placeholders. + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_placeholders( + ... parse_one("select * from :tbl where ? = ?"), + ... exp.to_identifier("str_col"), "b", tbl=exp.to_identifier("foo") + ... ).sql() + "SELECT * FROM foo WHERE str_col = 'b'" + + Returns: + The mapped expression. 
+ """ + + def _replace_placeholders(node: Expression, args, **kwargs) -> Expression: + if isinstance(node, Placeholder): + if node.this: + new_name = kwargs.get(node.this) + if new_name is not None: + return convert(new_name) + else: + try: + return convert(next(args)) + except StopIteration: + pass + return node + + return expression.transform(_replace_placeholders, iter(args), **kwargs) + + +def expand( + expression: Expression, + sources: t.Dict[str, Query | t.Callable[[], Query]], + dialect: DialectType = None, + copy: bool = True, +) -> Expression: + """Transforms an expression by expanding all referenced sources into subqueries. + + Examples: + >>> from sqlglot import parse_one + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() + 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' + + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql() + 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */' + + Args: + expression: The expression to expand. + sources: A dict of name to query or a callable that provides a query on demand. + dialect: The dialect of the sources dict or the callable. + copy: Whether to copy the expression during transformation. Defaults to True. + + Returns: + The transformed expression. + """ + normalized_sources = { + normalize_table_name(k, dialect=dialect): v for k, v in sources.items() + } + + def _expand(node: Expression): + if isinstance(node, Table): + name = normalize_table_name(node, dialect=dialect) + source = normalized_sources.get(name) + + if source: + # Create a subquery with the same alias (or table name if no alias) + parsed_source = source() if callable(source) else source + subquery = parsed_source.subquery(node.alias or name) + subquery.comments = [f"source: {name}"] + + # Continue expanding within the subquery + return subquery.transform(_expand, copy=False) + + return node + + return expression.transform(_expand, copy=copy) + + +def func( + name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs +) -> Func: + """ + Returns a Func expression. + + Examples: + >>> func("abs", 5).sql() + 'ABS(5)' + + >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql() + 'CAST(5 AS DOUBLE)' + + Args: + name: the name of the function to build. + args: the args used to instantiate the function of interest. + copy: whether to copy the argument expressions. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Note: + The arguments `args` and `kwargs` are mutually exclusive. + + Returns: + An instance of the function of interest, or an anonymous function, if `name` doesn't + correspond to an existing `sqlglot.expressions.Func` class. 
+ """ + if args and kwargs: + raise ValueError("Can't use both args and kwargs to instantiate a function.") + + from bigframes_vendored.sqlglot.dialects.dialect import Dialect + + dialect = Dialect.get_or_raise(dialect) + + converted: t.List[Expression] = [ + maybe_parse(arg, dialect=dialect, copy=copy) for arg in args + ] + kwargs = { + key: maybe_parse(value, dialect=dialect, copy=copy) + for key, value in kwargs.items() + } + + constructor = dialect.parser_class.FUNCTIONS.get(name.upper()) + if constructor: + if converted: + if "dialect" in constructor.__code__.co_varnames: + function = constructor(converted, dialect=dialect) + else: + function = constructor(converted) + elif constructor.__name__ == "from_arg_list": + function = constructor.__self__(**kwargs) # type: ignore + else: + constructor = FUNCTION_BY_NAME.get(name.upper()) + if constructor: + function = constructor(**kwargs) + else: + raise ValueError( + f"Unable to convert '{name}' into a Func. Either manually construct " + "the Func expression of interest or parse the function call." + ) + else: + kwargs = kwargs or {"expressions": converted} + function = Anonymous(this=name, **kwargs) + + for error_message in function.error_messages(converted): + raise ValueError(error_message) + + return function + + +def case( + expression: t.Optional[ExpOrStr] = None, + **opts, +) -> Case: + """ + Initialize a CASE statement. + + Example: + case().when("a = 1", "foo").else_("bar") + + Args: + expression: Optionally, the input expression (not all dialects support this) + **opts: Extra keyword arguments for parsing `expression` + """ + if expression is not None: + this = maybe_parse(expression, **opts) + else: + this = None + return Case(this=this, ifs=[]) + + +def array( + *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs +) -> Array: + """ + Returns an array. + + Examples: + >>> array(1, 'x').sql() + 'ARRAY(1, x)' + + Args: + expressions: the expressions to add to the array. + copy: whether to copy the argument expressions. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Returns: + An array expression. + """ + return Array( + expressions=[ + maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) + for expression in expressions + ] + ) + + +def tuple_( + *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs +) -> Tuple: + """ + Returns an tuple. + + Examples: + >>> tuple_(1, 'x').sql() + '(1, x)' + + Args: + expressions: the expressions to add to the tuple. + copy: whether to copy the argument expressions. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Returns: + A tuple expression. + """ + return Tuple( + expressions=[ + maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) + for expression in expressions + ] + ) + + +def true() -> Boolean: + """ + Returns a true Boolean expression. + """ + return Boolean(this=True) + + +def false() -> Boolean: + """ + Returns a false Boolean expression. + """ + return Boolean(this=False) + + +def null() -> Null: + """ + Returns a Null expression. 
+ """ + return Null() + + +NONNULL_CONSTANTS = ( + Literal, + Boolean, +) + +CONSTANTS = ( + Literal, + Boolean, + Null, +) diff --git a/third_party/bigframes_vendored/sqlglot/generator.py b/third_party/bigframes_vendored/sqlglot/generator.py new file mode 100644 index 0000000000..f86f529d15 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/generator.py @@ -0,0 +1,5822 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import reduce, wraps +import logging +import re +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.errors import ( + concat_messages, + ErrorLevel, + UnsupportedError, +) +from bigframes_vendored.sqlglot.helper import ( + apply_index_offset, + csv, + name_sequence, + seq_get, +) +from bigframes_vendored.sqlglot.jsonpath import ( + ALL_JSON_PATH_PARTS, + JSON_PATH_PART_TRANSFORMS, +) +from bigframes_vendored.sqlglot.time import format_time +from bigframes_vendored.sqlglot.tokens import TokenType + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + G = t.TypeVar("G", bound="Generator") + GeneratorMethod = t.Callable[[G, E], str] + +logger = logging.getLogger("sqlglot") + +ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") +UNSUPPORTED_TEMPLATE = ( + "Argument '{}' is not supported for expression '{}' when targeting {}." +) + + +def unsupported_args( + *args: t.Union[str, t.Tuple[str, str]], +) -> t.Callable[[GeneratorMethod], GeneratorMethod]: + """ + Decorator that can be used to mark certain args of an `Expression` subclass as unsupported. + It expects a sequence of argument names or pairs of the form (argument_name, diagnostic_msg). + """ + diagnostic_by_arg: t.Dict[str, t.Optional[str]] = {} + for arg in args: + if isinstance(arg, str): + diagnostic_by_arg[arg] = None + else: + diagnostic_by_arg[arg[0]] = arg[1] + + def decorator(func: GeneratorMethod) -> GeneratorMethod: + @wraps(func) + def _func(generator: G, expression: E) -> str: + expression_name = expression.__class__.__name__ + dialect_name = generator.dialect.__class__.__name__ + + for arg_name, diagnostic in diagnostic_by_arg.items(): + if expression.args.get(arg_name): + diagnostic = diagnostic or UNSUPPORTED_TEMPLATE.format( + arg_name, expression_name, dialect_name + ) + generator.unsupported(diagnostic) + + return func(generator, expression) + + return _func + + return decorator + + +class _Generator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Remove transforms that correspond to unsupported JSONPathPart expressions + for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS: + klass.TRANSFORMS.pop(part, None) + + return klass + + +class Generator(metaclass=_Generator): + """ + Generator converts a given syntax tree to the corresponding SQL string. + + Args: + pretty: Whether to format the produced SQL string. + Default: False. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True: Always quote except for specials cases. + 'safe': Only quote identifiers that are case insensitive. + normalize: Whether to normalize identifiers to lowercase. + Default: False. + pad: The pad size in a formatted string. For example, this affects the indentation of + a projection in a query, relative to its nesting level. + Default: 2. 
+ indent: The indentation size in a formatted string. For example, this affects the + indentation of subqueries and filters under a `WHERE` clause. + Default: 2. + normalize_functions: How to normalize function names. Possible values are: + "upper" or True (default): Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + unsupported_level: Determines the generator's behavior when it encounters unsupported expressions. + Default ErrorLevel.WARN. + max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError. + This is only relevant if unsupported_level is ErrorLevel.RAISE. + Default: 3 + leading_comma: Whether the comma is leading or trailing in select expressions. + This is only relevant when generating in pretty mode. + Default: False + max_text_width: The max number of characters in a segment before creating new lines in pretty mode. + The default is on the smaller end because the length only represents a segment and not the true + line length. + Default: 80 + comments: Whether to preserve comments in the output SQL code. + Default: True + """ + + TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { + **JSON_PATH_PART_TRANSFORMS, + exp.Adjacent: lambda self, e: self.binary(e, "-|-"), + exp.AllowedValuesProperty: lambda self, e: f"ALLOWED_VALUES {self.expressions(e, flat=True)}", + exp.AnalyzeColumns: lambda self, e: self.sql(e, "this"), + exp.AnalyzeWith: lambda self, e: self.expressions(e, prefix="WITH ", sep=" "), + exp.ArrayContainsAll: lambda self, e: self.binary(e, "@>"), + exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), + exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", + exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}", + exp.CaseSpecificColumnConstraint: lambda _, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", + exp.Ceil: lambda self, e: self.ceil_floor(e), + exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", + exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", + exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", + exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", + exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}", + exp.ConvertToCharset: lambda self, e: self.func( + "CONVERT", e.this, e.args["dest"], e.args.get("source") + ), + exp.CopyGrantsProperty: lambda *_: "COPY GRANTS", + exp.CredentialsProperty: lambda self, e: f"CREDENTIALS=({self.expressions(e, 'expressions', sep=' ')})", + exp.CurrentCatalog: lambda *_: "CURRENT_CATALOG", + exp.SessionUser: lambda *_: "SESSION_USER", + exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", + exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", + exp.DynamicProperty: lambda *_: "DYNAMIC", + exp.EmptyProperty: lambda *_: "EMPTY", + exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", + exp.EnviromentProperty: lambda self, e: f"ENVIRONMENT ({self.expressions(e, flat=True)})", + exp.EphemeralColumnConstraint: lambda self, e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}", + exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}", + 
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), + exp.Except: lambda self, e: self.set_operations(e), + exp.ExternalProperty: lambda *_: "EXTERNAL", + exp.Floor: lambda self, e: self.ceil_floor(e), + exp.Get: lambda self, e: self.get_put_sql(e), + exp.GlobalProperty: lambda *_: "GLOBAL", + exp.HeapProperty: lambda *_: "HEAP", + exp.IcebergProperty: lambda *_: "ICEBERG", + exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})", + exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", + exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", + exp.Intersect: lambda self, e: self.set_operations(e), + exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", + exp.Int64: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.BIGINT)), + exp.JSONBContainsAnyTopKeys: lambda self, e: self.binary(e, "?|"), + exp.JSONBContainsAllTopKeys: lambda self, e: self.binary(e, "?&"), + exp.JSONBDeleteAtPath: lambda self, e: self.binary(e, "#-"), + exp.LanguageProperty: lambda self, e: self.naked_property(e), + exp.LocationProperty: lambda self, e: self.naked_property(e), + exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.MaterializedProperty: lambda *_: "MATERIALIZED", + exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", + exp.NoPrimaryIndexProperty: lambda *_: "NO PRIMARY INDEX", + exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION", + exp.OnCommitProperty: lambda _, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", + exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", + exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", + exp.Operator: lambda self, e: self.binary( + e, "" + ), # The operator is produced in `binary` + exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}", + exp.ExtendsLeft: lambda self, e: self.binary(e, "&<"), + exp.ExtendsRight: lambda self, e: self.binary(e, "&>"), + exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", + exp.PartitionedByBucket: lambda self, e: self.func( + "BUCKET", e.this, e.expression + ), + exp.PartitionByTruncate: lambda self, e: self.func( + "TRUNCATE", e.this, e.expression + ), + exp.PivotAny: lambda self, e: f"ANY{self.sql(e, 'this')}", + exp.PositionalColumn: lambda self, e: f"#{self.sql(e, 'this')}", + exp.ProjectionPolicyColumnConstraint: lambda self, e: f"PROJECTION POLICY {self.sql(e, 'this')}", + exp.ZeroFillColumnConstraint: lambda self, e: "ZEROFILL", + exp.Put: lambda self, e: self.get_put_sql(e), + exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}", + exp.ReturnsProperty: lambda self, e: ( + "RETURNS NULL ON NULL INPUT" + if e.args.get("null") + else self.naked_property(e) + ), + exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", + exp.SecureProperty: lambda *_: "SECURE", + exp.SecurityProperty: lambda self, e: f"SECURITY {self.sql(e, 'this')}", + exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), + exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", + exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}", + exp.SqlReadWriteProperty: lambda _, e: e.name, + exp.SqlSecurityProperty: lambda self, e: f"SQL 
SECURITY {self.sql(e, 'this')}", + exp.StabilityProperty: lambda _, e: e.name, + exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}", + exp.StreamingTableProperty: lambda *_: "STREAMING", + exp.StrictProperty: lambda *_: "STRICT", + exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}", + exp.TableColumn: lambda self, e: self.sql(e.this), + exp.Tags: lambda self, e: f"TAG ({self.expressions(e, flat=True)})", + exp.TemporaryProperty: lambda *_: "TEMPORARY", + exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", + exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}", + exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", + exp.TransformModelProperty: lambda self, e: self.func( + "TRANSFORM", *e.expressions + ), + exp.TransientProperty: lambda *_: "TRANSIENT", + exp.Union: lambda self, e: self.set_operations(e), + exp.UnloggedProperty: lambda *_: "UNLOGGED", + exp.UsingTemplateProperty: lambda self, e: f"USING TEMPLATE {self.sql(e, 'this')}", + exp.UsingData: lambda self, e: f"USING DATA {self.sql(e, 'this')}", + exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE", + exp.UtcDate: lambda self, e: self.sql( + exp.CurrentDate(this=exp.Literal.string("UTC")) + ), + exp.UtcTime: lambda self, e: self.sql( + exp.CurrentTime(this=exp.Literal.string("UTC")) + ), + exp.UtcTimestamp: lambda self, e: self.sql( + exp.CurrentTimestamp(this=exp.Literal.string("UTC")) + ), + exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), + exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}", + exp.VolatileProperty: lambda *_: "VOLATILE", + exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", + exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}", + exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}", + exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}", + exp.ForceProperty: lambda *_: "FORCE", + } + + # Whether null ordering is supported in order by + # True: Full Support, None: No support, False: No support for certain cases + # such as window specifications, aggregate functions etc + NULL_ORDERING_SUPPORTED: t.Optional[bool] = True + + # Whether ignore nulls is inside the agg or outside. + # FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER + IGNORE_NULLS_IN_FUNC = False + + # Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported + LOCKING_READS_SUPPORTED = False + + # Whether the EXCEPT and INTERSECT operations can return duplicates + EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = True + + # Wrap derived values in parens, usually standard but spark doesn't support it + WRAP_DERIVED_VALUES = True + + # Whether create function uses an AS before the RETURN + CREATE_FUNCTION_RETURN_AS = True + + # Whether MERGE ... WHEN MATCHED BY SOURCE is allowed + MATCHED_BY_SOURCE = True + + # Whether the INTERVAL expression works only with values like '1 day' + SINGLE_STRING_INTERVAL = False + + # Whether the plural form of date parts like day (i.e. 
"days") is supported in INTERVALs + INTERVAL_ALLOWS_PLURAL_FORM = True + + # Whether limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") + LIMIT_FETCH = "ALL" + + # Whether limit and fetch allows expresions or just limits + LIMIT_ONLY_LITERALS = False + + # Whether a table is allowed to be renamed with a db + RENAME_TABLE_WITH_DB = True + + # The separator for grouping sets and rollups + GROUPINGS_SEP = "," + + # The string used for creating an index on a table + INDEX_ON = "ON" + + # Whether join hints should be generated + JOIN_HINTS = True + + # Whether table hints should be generated + TABLE_HINTS = True + + # Whether query hints should be generated + QUERY_HINTS = True + + # What kind of separator to use for query hints + QUERY_HINT_SEP = ", " + + # Whether comparing against booleans (e.g. x IS TRUE) is supported + IS_BOOL_ALLOWED = True + + # Whether to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement + DUPLICATE_KEY_UPDATE_WITH_SET = True + + # Whether to generate the limit as TOP instead of LIMIT + LIMIT_IS_TOP = False + + # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... + RETURNING_END = True + + # Whether to generate an unquoted value for EXTRACT's date part argument + EXTRACT_ALLOWS_QUOTES = True + + # Whether TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax + TZ_TO_WITH_TIME_ZONE = False + + # Whether the NVL2 function is supported + NVL2_SUPPORTED = True + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax + SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") + + # Whether VALUES statements can be used as derived tables. + # MySQL 5 and Redshift do not allow this, so when False, it will convert + # SELECT * VALUES into SELECT UNION + VALUES_AS_TABLE = True + + # Whether the word COLUMN is included when adding a column with ALTER TABLE + ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True + + # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery) + UNNEST_WITH_ORDINALITY = True + + # Whether FILTER (WHERE cond) can be used for conditional aggregation + AGGREGATE_FILTER_SUPPORTED = True + + # Whether JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds + SEMI_ANTI_JOIN_WITH_SIDE = True + + # Whether to include the type of a computed column in the CREATE DDL + COMPUTED_COLUMN_WITH_TYPE = True + + # Whether CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY + SUPPORTS_TABLE_COPY = True + + # Whether parentheses are required around the table sample's expression + TABLESAMPLE_REQUIRES_PARENS = True + + # Whether a table sample clause's size needs to be followed by the ROWS keyword + TABLESAMPLE_SIZE_IS_ROWS = True + + # The keyword(s) to use when generating a sample clause + TABLESAMPLE_KEYWORDS = "TABLESAMPLE" + + # Whether the TABLESAMPLE clause supports a method name, like BERNOULLI + TABLESAMPLE_WITH_METHOD = True + + # The keyword to use when specifying the seed of a sample clause + TABLESAMPLE_SEED_KEYWORD = "SEED" + + # Whether COLLATE is a function instead of a binary operator + COLLATE_IS_FUNC = False + + # Whether data types support additional specifiers like e.g. 
CHAR or BYTE (oracle) + DATA_TYPE_SPECIFIERS_ALLOWED = False + + # Whether conditions require booleans WHERE x = 0 vs WHERE x + ENSURE_BOOLS = False + + # Whether the "RECURSIVE" keyword is required when defining recursive CTEs + CTE_RECURSIVE_KEYWORD_REQUIRED = True + + # Whether CONCAT requires >1 arguments + SUPPORTS_SINGLE_ARG_CONCAT = True + + # Whether LAST_DAY function supports a date part argument + LAST_DAY_SUPPORTS_DATE_PART = True + + # Whether named columns are allowed in table aliases + SUPPORTS_TABLE_ALIAS_COLUMNS = True + + # Whether UNPIVOT aliases are Identifiers (False means they're Literals) + UNPIVOT_ALIASES_ARE_IDENTIFIERS = True + + # What delimiter to use for separating JSON key/value pairs + JSON_KEY_VALUE_PAIR_SEP = ":" + + # INSERT OVERWRITE TABLE x override + INSERT_OVERWRITE = " OVERWRITE TABLE" + + # Whether the SELECT .. INTO syntax is used instead of CTAS + SUPPORTS_SELECT_INTO = False + + # Whether UNLOGGED tables can be created + SUPPORTS_UNLOGGED_TABLES = False + + # Whether the CREATE TABLE LIKE statement is supported + SUPPORTS_CREATE_TABLE_LIKE = True + + # Whether the LikeProperty needs to be specified inside of the schema clause + LIKE_PROPERTY_INSIDE_SCHEMA = False + + # Whether DISTINCT can be followed by multiple args in an AggFunc. If not, it will be + # transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple conseisting of the args + MULTI_ARG_DISTINCT = True + + # Whether the JSON extraction operators expect a value of type JSON + JSON_TYPE_REQUIRED_FOR_EXTRACTION = False + + # Whether bracketed keys like ["foo"] are supported in JSON paths + JSON_PATH_BRACKETED_KEY_SUPPORTED = True + + # Whether to escape keys using single quotes in JSON paths + JSON_PATH_SINGLE_QUOTE_ESCAPE = False + + # The JSONPathPart expressions supported by this dialect + SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy() + + # Whether any(f(x) for x in array) can be implemented by this dialect + CAN_IMPLEMENT_ARRAY_ANY = False + + # Whether the function TO_NUMBER is supported + SUPPORTS_TO_NUMBER = True + + # Whether EXCLUDE in window specification is supported + SUPPORTS_WINDOW_EXCLUDE = False + + # Whether or not set op modifiers apply to the outer set op or select. + # SELECT * FROM x UNION SELECT * FROM y LIMIT 1 + # True means limit 1 happens after the set op, False means it it happens on y. 
+ SET_OP_MODIFIERS = True + + # Whether parameters from COPY statement are wrapped in parentheses + COPY_PARAMS_ARE_WRAPPED = True + + # Whether values of params are set with "=" token or empty space + COPY_PARAMS_EQ_REQUIRED = False + + # Whether COPY statement has INTO keyword + COPY_HAS_INTO_KEYWORD = True + + # Whether the conditional TRY(expression) function is supported + TRY_SUPPORTED = True + + # Whether the UESCAPE syntax in unicode strings is supported + SUPPORTS_UESCAPE = True + + # Function used to replace escaped unicode codes in unicode strings + UNICODE_SUBSTITUTE: t.Optional[t.Callable[[re.Match[str]], str]] = None + + # The keyword to use when generating a star projection with excluded columns + STAR_EXCEPT = "EXCEPT" + + # The HEX function name + HEX_FUNC = "HEX" + + # The keywords to use when prefixing & separating WITH based properties + WITH_PROPERTIES_PREFIX = "WITH" + + # Whether to quote the generated expression of exp.JsonPath + QUOTE_JSON_PATH = True + + # Whether the text pattern/fill (3rd) parameter of RPAD()/LPAD() is optional (defaults to space) + PAD_FILL_PATTERN_IS_REQUIRED = False + + # Whether a projection can explode into multiple rows, e.g. by unnesting an array. + SUPPORTS_EXPLODING_PROJECTIONS = True + + # Whether ARRAY_CONCAT can be generated with varlen args or if it should be reduced to 2-arg version + ARRAY_CONCAT_IS_VAR_LEN = True + + # Whether CONVERT_TIMEZONE() is supported; if not, it will be generated as exp.AtTimeZone + SUPPORTS_CONVERT_TIMEZONE = False + + # Whether MEDIAN(expr) is supported; if not, it will be generated as PERCENTILE_CONT(expr, 0.5) + SUPPORTS_MEDIAN = True + + # Whether UNIX_SECONDS(timestamp) is supported + SUPPORTS_UNIX_SECONDS = False + + # Whether to wrap in `AlterSet`, e.g., ALTER ... SET () + ALTER_SET_WRAPPED = False + + # Whether to normalize the date parts in EXTRACT( FROM ) into a common representation + # For instance, to extract the day of week in ISO semantics, one can use ISODOW, DAYOFWEEKISO etc depending on the dialect. + # TODO: The normalization should be done by default once we've tested it across all dialects. + NORMALIZE_EXTRACT_DATE_PARTS = False + + # The name to generate for the JSONPath expression. If `None`, only `this` will be generated + PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON" + + # The function name of the exp.ArraySize expression + ARRAY_SIZE_NAME: str = "ARRAY_LENGTH" + + # The syntax to use when altering the type of a column + ALTER_SET_TYPE = "SET DATA TYPE" + + # Whether exp.ArraySize should generate the dimension arg too (valid for Postgres & DuckDB) + # None -> Doesn't support it at all + # False (DuckDB) -> Has backwards-compatible support, but preferably generated without + # True (Postgres) -> Explicitly requires it + ARRAY_SIZE_DIM_REQUIRED: t.Optional[bool] = None + + # Whether a multi-argument DECODE(...) function is supported. 
If not, a CASE expression is generated + SUPPORTS_DECODE_CASE = True + + # Whether SYMMETRIC and ASYMMETRIC flags are supported with BETWEEN expression + SUPPORTS_BETWEEN_FLAGS = False + + # Whether LIKE and ILIKE support quantifiers such as LIKE ANY/ALL/SOME + SUPPORTS_LIKE_QUANTIFIERS = True + + # Prefix which is appended to exp.Table expressions in MATCH AGAINST + MATCH_AGAINST_TABLE_PREFIX: t.Optional[str] = None + + # Whether to include the VARIABLE keyword for SET assignments + SET_ASSIGNMENT_REQUIRES_VARIABLE_KEYWORD = False + + TYPE_MAPPING = { + exp.DataType.Type.DATETIME2: "TIMESTAMP", + exp.DataType.Type.NCHAR: "CHAR", + exp.DataType.Type.NVARCHAR: "VARCHAR", + exp.DataType.Type.MEDIUMTEXT: "TEXT", + exp.DataType.Type.LONGTEXT: "TEXT", + exp.DataType.Type.TINYTEXT: "TEXT", + exp.DataType.Type.BLOB: "VARBINARY", + exp.DataType.Type.MEDIUMBLOB: "BLOB", + exp.DataType.Type.LONGBLOB: "BLOB", + exp.DataType.Type.TINYBLOB: "BLOB", + exp.DataType.Type.INET: "INET", + exp.DataType.Type.ROWVERSION: "VARBINARY", + exp.DataType.Type.SMALLDATETIME: "TIMESTAMP", + } + + UNSUPPORTED_TYPES: set[exp.DataType.Type] = set() + + TIME_PART_SINGULARS = { + "MICROSECONDS": "MICROSECOND", + "SECONDS": "SECOND", + "MINUTES": "MINUTE", + "HOURS": "HOUR", + "DAYS": "DAY", + "WEEKS": "WEEK", + "MONTHS": "MONTH", + "QUARTERS": "QUARTER", + "YEARS": "YEAR", + } + + AFTER_HAVING_MODIFIER_TRANSFORMS = { + "cluster": lambda self, e: self.sql(e, "cluster"), + "distribute": lambda self, e: self.sql(e, "distribute"), + "sort": lambda self, e: self.sql(e, "sort"), + "windows": lambda self, e: ( + self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True) + if e.args.get("windows") + else "" + ), + "qualify": lambda self, e: self.sql(e, "qualify"), + } + + TOKEN_MAPPING: t.Dict[TokenType, str] = {} + + STRUCT_DELIMITER = ("<", ">") + + PARAMETER_TOKEN = "@" + NAMED_PLACEHOLDER_TOKEN = ":" + + EXPRESSION_PRECEDES_PROPERTIES_CREATABLES: t.Set[str] = set() + + PROPERTIES_LOCATION = { + exp.AllowedValuesProperty: exp.Properties.Location.POST_SCHEMA, + exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, + exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA, + exp.BackupProperty: exp.Properties.Location.POST_SCHEMA, + exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, + exp.ChecksumProperty: exp.Properties.Location.POST_NAME, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA, + exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA, + exp.Cluster: exp.Properties.Location.POST_SCHEMA, + exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA, + exp.DistributedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.DuplicateKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, + exp.DataDeletionProperty: exp.Properties.Location.POST_SCHEMA, + exp.DefinerProperty: exp.Properties.Location.POST_CREATE, + exp.DictRange: exp.Properties.Location.POST_SCHEMA, + exp.DictProperty: exp.Properties.Location.POST_SCHEMA, + exp.DynamicProperty: exp.Properties.Location.POST_CREATE, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, + exp.EmptyProperty: exp.Properties.Location.POST_SCHEMA, + exp.EncodeProperty: exp.Properties.Location.POST_EXPRESSION, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, + 
exp.EnviromentProperty: exp.Properties.Location.POST_SCHEMA, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA, + exp.ExternalProperty: exp.Properties.Location.POST_CREATE, + exp.FallbackProperty: exp.Properties.Location.POST_NAME, + exp.FileFormatProperty: exp.Properties.Location.POST_WITH, + exp.FreespaceProperty: exp.Properties.Location.POST_NAME, + exp.GlobalProperty: exp.Properties.Location.POST_CREATE, + exp.HeapProperty: exp.Properties.Location.POST_WITH, + exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA, + exp.IcebergProperty: exp.Properties.Location.POST_CREATE, + exp.IncludeProperty: exp.Properties.Location.POST_SCHEMA, + exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA, + exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, + exp.JournalProperty: exp.Properties.Location.POST_NAME, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, + exp.LockProperty: exp.Properties.Location.POST_SCHEMA, + exp.LockingProperty: exp.Properties.Location.POST_ALIAS, + exp.LogProperty: exp.Properties.Location.POST_NAME, + exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, + exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, + exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, + exp.OnProperty: exp.Properties.Location.POST_SCHEMA, + exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, + exp.Order: exp.Properties.Location.POST_SCHEMA, + exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA, + exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, + exp.Property: exp.Properties.Location.POST_WITH, + exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, + exp.SampleProperty: exp.Properties.Location.POST_SCHEMA, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, + exp.SecureProperty: exp.Properties.Location.POST_CREATE, + exp.SecurityProperty: exp.Properties.Location.POST_SCHEMA, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.Set: exp.Properties.Location.POST_SCHEMA, + exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, + exp.SetProperty: exp.Properties.Location.POST_CREATE, + exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA, + exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION, + exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA, + exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.StorageHandlerProperty: exp.Properties.Location.POST_SCHEMA, + exp.StreamingTableProperty: exp.Properties.Location.POST_CREATE, + exp.StrictProperty: exp.Properties.Location.POST_SCHEMA, + exp.Tags: exp.Properties.Location.POST_WITH, + exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, + exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA, + exp.TransientProperty: exp.Properties.Location.POST_CREATE, + 
exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA, + exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, + exp.UnloggedProperty: exp.Properties.Location.POST_CREATE, + exp.UsingTemplateProperty: exp.Properties.Location.POST_SCHEMA, + exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.POST_CREATE, + exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, + exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, + exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA, + exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA, + exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA, + exp.ForceProperty: exp.Properties.Location.POST_CREATE, + } + + # Keywords that can't be used as unquoted identifier names + RESERVED_KEYWORDS: t.Set[str] = set() + + # Expressions whose comments are separated from them for better formatting + WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Command, + exp.Create, + exp.Describe, + exp.Delete, + exp.Drop, + exp.From, + exp.Insert, + exp.Join, + exp.MultitableInserts, + exp.Order, + exp.Group, + exp.Having, + exp.Select, + exp.SetOperation, + exp.Update, + exp.Where, + exp.With, + ) + + # Expressions that should not have their comments generated in maybe_comment + EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Binary, + exp.SetOperation, + ) + + # Expressions that can remain unwrapped when appearing in the context of an INTERVAL + UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Column, + exp.Literal, + exp.Neg, + exp.Paren, + ) + + PARAMETERIZABLE_TEXT_TYPES = { + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.CHAR, + exp.DataType.Type.NCHAR, + } + + # Expressions that need to have all CTEs under them bubbled up to them + EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set() + + RESPECT_IGNORE_NULLS_UNSUPPORTED_EXPRESSIONS: t.Tuple[ + t.Type[exp.Expression], ... 
+ ] = () + + SAFE_JSON_PATH_KEY_RE = exp.SAFE_IDENTIFIER_RE + + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" + + __slots__ = ( + "pretty", + "identify", + "normalize", + "pad", + "_indent", + "normalize_functions", + "unsupported_level", + "max_unsupported", + "leading_comma", + "max_text_width", + "comments", + "dialect", + "unsupported_messages", + "_escaped_quote_end", + "_escaped_byte_quote_end", + "_escaped_identifier_end", + "_next_name", + "_identifier_start", + "_identifier_end", + "_quote_json_path_key_using_brackets", + ) + + def __init__( + self, + pretty: t.Optional[bool] = None, + identify: str | bool = False, + normalize: bool = False, + pad: int = 2, + indent: int = 2, + normalize_functions: t.Optional[str | bool] = None, + unsupported_level: ErrorLevel = ErrorLevel.WARN, + max_unsupported: int = 3, + leading_comma: bool = False, + max_text_width: int = 80, + comments: bool = True, + dialect: DialectType = None, + ): + import bigframes_vendored.sqlglot + from bigframes_vendored.sqlglot.dialects import Dialect + + self.pretty = ( + pretty if pretty is not None else bigframes_vendored.sqlglot.pretty + ) + self.identify = identify + self.normalize = normalize + self.pad = pad + self._indent = indent + self.unsupported_level = unsupported_level + self.max_unsupported = max_unsupported + self.leading_comma = leading_comma + self.max_text_width = max_text_width + self.comments = comments + self.dialect = Dialect.get_or_raise(dialect) + + # This is both a Dialect property and a Generator argument, so we prioritize the latter + self.normalize_functions = ( + self.dialect.NORMALIZE_FUNCTIONS + if normalize_functions is None + else normalize_functions + ) + + self.unsupported_messages: t.List[str] = [] + self._escaped_quote_end: str = ( + self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END + ) + self._escaped_byte_quote_end: str = ( + self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.BYTE_END + if self.dialect.BYTE_END + else "" + ) + self._escaped_identifier_end = self.dialect.IDENTIFIER_END * 2 + + self._next_name = name_sequence("_t") + + self._identifier_start = self.dialect.IDENTIFIER_START + self._identifier_end = self.dialect.IDENTIFIER_END + + self._quote_json_path_key_using_brackets = True + + def generate(self, expression: exp.Expression, copy: bool = True) -> str: + """ + Generates the SQL string corresponding to the given syntax tree. + + Args: + expression: The syntax tree. + copy: Whether to copy the expression. The generator performs mutations so + it is safer to copy. + + Returns: + The SQL string corresponding to `expression`. 
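+
+        Example:
+            A minimal round trip; this assumes the vendored package re-exports
+            ``parse_one`` at the top level like upstream sqlglot:
+
+            >>> from bigframes_vendored.sqlglot import parse_one
+            >>> Generator().generate(parse_one("SELECT a FROM b"))
+            'SELECT a FROM b'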
+ """ + if copy: + expression = expression.copy() + + expression = self.preprocess(expression) + + self.unsupported_messages = [] + sql = self.sql(expression).strip() + + if self.pretty: + sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n") + + if self.unsupported_level == ErrorLevel.IGNORE: + return sql + + if self.unsupported_level == ErrorLevel.WARN: + for msg in self.unsupported_messages: + logger.warning(msg) + elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: + raise UnsupportedError( + concat_messages(self.unsupported_messages, self.max_unsupported) + ) + + return sql + + def preprocess(self, expression: exp.Expression) -> exp.Expression: + """Apply generic preprocessing transformations to a given expression.""" + expression = self._move_ctes_to_top_level(expression) + + if self.ENSURE_BOOLS: + from bigframes_vendored.sqlglot.transforms import ensure_bools + + expression = ensure_bools(expression) + + return expression + + def _move_ctes_to_top_level(self, expression: E) -> E: + if ( + not expression.parent + and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES + and any( + node.parent is not expression for node in expression.find_all(exp.With) + ) + ): + from bigframes_vendored.sqlglot.transforms import move_ctes_to_top_level + + expression = move_ctes_to_top_level(expression) + return expression + + def unsupported(self, message: str) -> None: + if self.unsupported_level == ErrorLevel.IMMEDIATE: + raise UnsupportedError(message) + self.unsupported_messages.append(message) + + def sep(self, sep: str = " ") -> str: + return f"{sep.strip()}\n" if self.pretty else sep + + def seg(self, sql: str, sep: str = " ") -> str: + return f"{self.sep(sep)}{sql}" + + def sanitize_comment(self, comment: str) -> str: + comment = " " + comment if comment[0].strip() else comment + comment = comment + " " if comment[-1].strip() else comment + + if not self.dialect.tokenizer_class.NESTED_COMMENTS: + # Necessary workaround to avoid syntax errors due to nesting: /* ... */ ... 
*/ + comment = comment.replace("*/", "* /") + + return comment + + def maybe_comment( + self, + sql: str, + expression: t.Optional[exp.Expression] = None, + comments: t.Optional[t.List[str]] = None, + separated: bool = False, + ) -> str: + comments = ( + ((expression and expression.comments) if comments is None else comments) # type: ignore + if self.comments + else None + ) + + if not comments or isinstance(expression, self.EXCLUDE_COMMENTS): + return sql + + comments_sql = " ".join( + f"/*{self.sanitize_comment(comment)}*/" for comment in comments if comment + ) + + if not comments_sql: + return sql + + comments_sql = self._replace_line_breaks(comments_sql) + + if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return ( + f"{self.sep()}{comments_sql}{sql}" + if not sql or sql[0].isspace() + else f"{comments_sql}{self.sep()}{sql}" + ) + + return f"{sql} {comments_sql}" + + def wrap(self, expression: exp.Expression | str) -> str: + this_sql = ( + self.sql(expression) + if isinstance(expression, exp.UNWRAPPED_QUERIES) + else self.sql(expression, "this") + ) + if not this_sql: + return "()" + + this_sql = self.indent(this_sql, level=1, pad=0) + return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" + + def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str: + original = self.identify + self.identify = False + result = func(*args, **kwargs) + self.identify = original + return result + + def normalize_func(self, name: str) -> str: + if self.normalize_functions == "upper" or self.normalize_functions is True: + return name.upper() + if self.normalize_functions == "lower": + return name.lower() + return name + + def indent( + self, + sql: str, + level: int = 0, + pad: t.Optional[int] = None, + skip_first: bool = False, + skip_last: bool = False, + ) -> str: + if not self.pretty or not sql: + return sql + + pad = self.pad if pad is None else pad + lines = sql.split("\n") + + return "\n".join( + ( + line + if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) + else f"{' ' * (level * self._indent + pad)}{line}" + ) + for i, line in enumerate(lines) + ) + + def sql( + self, + expression: t.Optional[str | exp.Expression], + key: t.Optional[str] = None, + comment: bool = True, + ) -> str: + if not expression: + return "" + + if isinstance(expression, str): + return expression + + if key: + value = expression.args.get(key) + if value: + return self.sql(value) + return "" + + transform = self.TRANSFORMS.get(expression.__class__) + + if callable(transform): + sql = transform(self, expression) + elif isinstance(expression, exp.Expression): + exp_handler_name = f"{expression.key}_sql" + + if hasattr(self, exp_handler_name): + sql = getattr(self, exp_handler_name)(expression) + elif isinstance(expression, exp.Func): + sql = self.function_fallback_sql(expression) + elif isinstance(expression, exp.Property): + sql = self.property_sql(expression) + else: + raise ValueError( + f"Unsupported expression type {expression.__class__.__name__}" + ) + else: + raise ValueError( + f"Expected an Expression. 
Received {type(expression)}: {expression}" + ) + + return self.maybe_comment(sql, expression) if self.comments and comment else sql + + def uncache_sql(self, expression: exp.Uncache) -> str: + table = self.sql(expression, "this") + exists_sql = " IF EXISTS" if expression.args.get("exists") else "" + return f"UNCACHE TABLE{exists_sql} {table}" + + def cache_sql(self, expression: exp.Cache) -> str: + lazy = " LAZY" if expression.args.get("lazy") else "" + table = self.sql(expression, "this") + options = expression.args.get("options") + options = ( + f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" + if options + else "" + ) + sql = self.sql(expression, "expression") + sql = f" AS{self.sep()}{sql}" if sql else "" + sql = f"CACHE{lazy} TABLE {table}{options}{sql}" + return self.prepend_ctes(expression, sql) + + def characterset_sql(self, expression: exp.CharacterSet) -> str: + if isinstance(expression.parent, exp.Cast): + return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" + default = "DEFAULT " if expression.args.get("default") else "" + return f"{default}CHARACTER SET={self.sql(expression, 'this')}" + + def column_parts(self, expression: exp.Column) -> str: + return ".".join( + self.sql(part) + for part in ( + expression.args.get("catalog"), + expression.args.get("db"), + expression.args.get("table"), + expression.args.get("this"), + ) + if part + ) + + def column_sql(self, expression: exp.Column) -> str: + join_mark = " (+)" if expression.args.get("join_mark") else "" + + if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS: + join_mark = "" + self.unsupported( + "Outer join syntax using the (+) operator is not supported." + ) + + return f"{self.column_parts(expression)}{join_mark}" + + def pseudocolumn_sql(self, expression: exp.Pseudocolumn) -> str: + return self.column_sql(expression) + + def columnposition_sql(self, expression: exp.ColumnPosition) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + position = self.sql(expression, "position") + return f"{position}{this}" + + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + column = self.sql(expression, "this") + kind = self.sql(expression, "kind") + constraints = self.expressions( + expression, key="constraints", sep=" ", flat=True + ) + exists = "IF NOT EXISTS " if expression.args.get("exists") else "" + kind = f"{sep}{kind}" if kind else "" + constraints = f" {constraints}" if constraints else "" + position = self.sql(expression, "position") + position = f" {position}" if position else "" + + if ( + expression.find(exp.ComputedColumnConstraint) + and not self.COMPUTED_COLUMN_WITH_TYPE + ): + kind = "" + + return f"{exists}{column}{kind}{constraints}{position}" + + def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: + this = self.sql(expression, "this") + kind_sql = self.sql(expression, "kind").strip() + return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql + + def computedcolumnconstraint_sql( + self, expression: exp.ComputedColumnConstraint + ) -> str: + this = self.sql(expression, "this") + if expression.args.get("not_null"): + persisted = " PERSISTED NOT NULL" + elif expression.args.get("persisted"): + persisted = " PERSISTED" + else: + persisted = "" + + return f"AS {this}{persisted}" + + def autoincrementcolumnconstraint_sql(self, _) -> str: + return self.token_sql(TokenType.AUTO_INCREMENT) + + def compresscolumnconstraint_sql( + self, expression: exp.CompressColumnConstraint + ) -> str: + if 
isinstance(expression.this, list): + this = self.wrap(self.expressions(expression, key="this", flat=True)) + else: + this = self.sql(expression, "this") + + return f"COMPRESS {this}" + + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + this = "" + if expression.this is not None: + on_null = " ON NULL" if expression.args.get("on_null") else "" + this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}" + + start = expression.args.get("start") + start = f"START WITH {start}" if start else "" + increment = expression.args.get("increment") + increment = f" INCREMENT BY {increment}" if increment else "" + minvalue = expression.args.get("minvalue") + minvalue = f" MINVALUE {minvalue}" if minvalue else "" + maxvalue = expression.args.get("maxvalue") + maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" + cycle = expression.args.get("cycle") + cycle_sql = "" + + if cycle is not None: + cycle_sql = f"{' NO' if not cycle else ''} CYCLE" + cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql + + sequence_opts = "" + if start or increment or cycle_sql: + sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" + sequence_opts = f" ({sequence_opts.strip()})" + + expr = self.sql(expression, "expression") + expr = f"({expr})" if expr else "IDENTITY" + + return f"GENERATED{this} AS {expr}{sequence_opts}" + + def generatedasrowcolumnconstraint_sql( + self, expression: exp.GeneratedAsRowColumnConstraint + ) -> str: + start = "START" if expression.args.get("start") else "END" + hidden = " HIDDEN" if expression.args.get("hidden") else "" + return f"GENERATED ALWAYS AS ROW {start}{hidden}" + + def periodforsystemtimeconstraint_sql( + self, expression: exp.PeriodForSystemTimeConstraint + ) -> str: + return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})" + + def notnullcolumnconstraint_sql( + self, expression: exp.NotNullColumnConstraint + ) -> str: + return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" + + def primarykeycolumnconstraint_sql( + self, expression: exp.PrimaryKeyColumnConstraint + ) -> str: + desc = expression.args.get("desc") + if desc is not None: + return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" + options = self.expressions(expression, key="options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"PRIMARY KEY{options}" + + def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + index_type = expression.args.get("index_type") + index_type = f" USING {index_type}" if index_type else "" + on_conflict = self.sql(expression, "on_conflict") + on_conflict = f" {on_conflict}" if on_conflict else "" + nulls_sql = " NULLS NOT DISTINCT" if expression.args.get("nulls") else "" + options = self.expressions(expression, key="options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"UNIQUE{nulls_sql}{this}{index_type}{on_conflict}{options}" + + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: + return self.sql(expression, "this") + + def create_sql(self, expression: exp.Create) -> str: + kind = self.sql(expression, "kind") + kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind + properties = expression.args.get("properties") + properties_locs = ( + self.locate_properties(properties) if properties else defaultdict() + ) + + 
this = self.createable_sql(expression, properties_locs) + + properties_sql = "" + if properties_locs.get( + exp.Properties.Location.POST_SCHEMA + ) or properties_locs.get(exp.Properties.Location.POST_WITH): + props_ast = exp.Properties( + expressions=[ + *properties_locs[exp.Properties.Location.POST_SCHEMA], + *properties_locs[exp.Properties.Location.POST_WITH], + ] + ) + props_ast.parent = expression + properties_sql = self.sql(props_ast) + + if properties_locs.get(exp.Properties.Location.POST_SCHEMA): + properties_sql = self.sep() + properties_sql + elif not self.pretty: + # Standalone POST_WITH properties need a leading whitespace in non-pretty mode + properties_sql = f" {properties_sql}" + + begin = " BEGIN" if expression.args.get("begin") else "" + end = " END" if expression.args.get("end") else "" + + expression_sql = self.sql(expression, "expression") + if expression_sql: + expression_sql = f"{begin}{self.sep()}{expression_sql}{end}" + + if self.CREATE_FUNCTION_RETURN_AS or not isinstance( + expression.expression, exp.Return + ): + postalias_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_ALIAS): + postalias_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[ + exp.Properties.Location.POST_ALIAS + ] + ), + wrapped=False, + ) + postalias_props_sql = ( + f" {postalias_props_sql}" if postalias_props_sql else "" + ) + expression_sql = f" AS{postalias_props_sql}{expression_sql}" + + postindex_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_INDEX): + postindex_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_INDEX] + ), + wrapped=False, + prefix=" ", + ) + + indexes = self.expressions(expression, key="indexes", indent=False, sep=" ") + indexes = f" {indexes}" if indexes else "" + index_sql = indexes + postindex_props_sql + + replace = " OR REPLACE" if expression.args.get("replace") else "" + refresh = " OR REFRESH" if expression.args.get("refresh") else "" + unique = " UNIQUE" if expression.args.get("unique") else "" + + clustered = expression.args.get("clustered") + if clustered is None: + clustered_sql = "" + elif clustered: + clustered_sql = " CLUSTERED COLUMNSTORE" + else: + clustered_sql = " NONCLUSTERED COLUMNSTORE" + + postcreate_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_CREATE): + postcreate_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_CREATE] + ), + sep=" ", + prefix=" ", + wrapped=False, + ) + + modifiers = "".join( + (clustered_sql, replace, refresh, unique, postcreate_props_sql) + ) + + postexpression_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): + postexpression_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION] + ), + sep=" ", + prefix=" ", + wrapped=False, + ) + + concurrently = " CONCURRENTLY" if expression.args.get("concurrently") else "" + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" + no_schema_binding = ( + " WITH NO SCHEMA BINDING" + if expression.args.get("no_schema_binding") + else "" + ) + + clone = self.sql(expression, "clone") + clone = f" {clone}" if clone else "" + + if kind in self.EXPRESSION_PRECEDES_PROPERTIES_CREATABLES: + properties_expression = f"{expression_sql}{properties_sql}" + else: + properties_expression = f"{properties_sql}{expression_sql}" + + expression_sql = f"CREATE{modifiers} 
{kind}{concurrently}{exists_sql} {this}{properties_expression}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}" + return self.prepend_ctes(expression, expression_sql) + + def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str: + start = self.sql(expression, "start") + start = f"START WITH {start}" if start else "" + increment = self.sql(expression, "increment") + increment = f" INCREMENT BY {increment}" if increment else "" + minvalue = self.sql(expression, "minvalue") + minvalue = f" MINVALUE {minvalue}" if minvalue else "" + maxvalue = self.sql(expression, "maxvalue") + maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" + owned = self.sql(expression, "owned") + owned = f" OWNED BY {owned}" if owned else "" + + cache = expression.args.get("cache") + if cache is None: + cache_str = "" + elif cache is True: + cache_str = " CACHE" + else: + cache_str = f" CACHE {cache}" + + options = self.expressions(expression, key="options", flat=True, sep=" ") + options = f" {options}" if options else "" + + return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip() + + def clone_sql(self, expression: exp.Clone) -> str: + this = self.sql(expression, "this") + shallow = "SHALLOW " if expression.args.get("shallow") else "" + keyword = ( + "COPY" + if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY + else "CLONE" + ) + return f"{shallow}{keyword} {this}" + + def describe_sql(self, expression: exp.Describe) -> str: + style = expression.args.get("style") + style = f" {style}" if style else "" + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + format = self.sql(expression, "format") + format = f" {format}" if format else "" + + return f"DESCRIBE{style}{format} {self.sql(expression, 'this')}{partition}" + + def heredoc_sql(self, expression: exp.Heredoc) -> str: + tag = self.sql(expression, "tag") + return f"${tag}${self.sql(expression, 'this')}${tag}$" + + def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: + with_ = self.sql(expression, "with_") + if with_: + sql = f"{with_}{self.sep()}{sql}" + return sql + + def with_sql(self, expression: exp.With) -> str: + sql = self.expressions(expression, flat=True) + recursive = ( + "RECURSIVE " + if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive") + else "" + ) + search = self.sql(expression, "search") + search = f" {search}" if search else "" + + return f"WITH {recursive}{sql}{search}" + + def cte_sql(self, expression: exp.CTE) -> str: + alias = expression.args.get("alias") + if alias: + alias.add_comments(expression.pop_comments()) + + alias_sql = self.sql(expression, "alias") + + materialized = expression.args.get("materialized") + if materialized is False: + materialized = "NOT MATERIALIZED " + elif materialized: + materialized = "MATERIALIZED " + + key_expressions = self.expressions(expression, key="key_expressions", flat=True) + key_expressions = f" USING KEY ({key_expressions})" if key_expressions else "" + + return f"{alias_sql}{key_expressions} AS {materialized or ''}{self.wrap(expression)}" + + def tablealias_sql(self, expression: exp.TableAlias) -> str: + alias = self.sql(expression, "this") + columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" + + if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS: + columns = "" + self.unsupported("Named columns are not supported in table alias.") + + if not alias and not 
self.dialect.UNNEST_COLUMN_ONLY: + alias = self._next_name() + + return f"{alias}{columns}" + + def bitstring_sql(self, expression: exp.BitString) -> str: + this = self.sql(expression, "this") + if self.dialect.BIT_START: + return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}" + return f"{int(this, 2)}" + + def hexstring_sql( + self, expression: exp.HexString, binary_function_repr: t.Optional[str] = None + ) -> str: + this = self.sql(expression, "this") + is_integer_type = expression.args.get("is_integer") + + if (is_integer_type and not self.dialect.HEX_STRING_IS_INTEGER_TYPE) or ( + not self.dialect.HEX_START and not binary_function_repr + ): + # Integer representation will be returned if: + # - The read dialect treats the hex value as integer literal but not the write + # - The transpilation is not supported (write dialect hasn't set HEX_START or the param flag) + return f"{int(this, 16)}" + + if not is_integer_type: + # Read dialect treats the hex value as BINARY/BLOB + if binary_function_repr: + # The write dialect supports the transpilation to its equivalent BINARY/BLOB + return self.func(binary_function_repr, exp.Literal.string(this)) + if self.dialect.HEX_STRING_IS_INTEGER_TYPE: + # The write dialect does not support the transpilation, it'll treat the hex value as INTEGER + self.unsupported( + "Unsupported transpilation from BINARY/BLOB hex string" + ) + + return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}" + + def bytestring_sql(self, expression: exp.ByteString) -> str: + this = self.sql(expression, "this") + if self.dialect.BYTE_START: + escaped_byte_string = self.escape_str( + this, + escape_backslash=False, + delimiter=self.dialect.BYTE_END, + escaped_delimiter=self._escaped_byte_quote_end, + ) + is_bytes = expression.args.get("is_bytes", False) + delimited_byte_string = ( + f"{self.dialect.BYTE_START}{escaped_byte_string}{self.dialect.BYTE_END}" + ) + if is_bytes and not self.dialect.BYTE_STRING_IS_BYTES_TYPE: + return self.sql( + exp.cast( + delimited_byte_string, + exp.DataType.Type.BINARY, + dialect=self.dialect, + ) + ) + if not is_bytes and self.dialect.BYTE_STRING_IS_BYTES_TYPE: + return self.sql( + exp.cast( + delimited_byte_string, + exp.DataType.Type.VARCHAR, + dialect=self.dialect, + ) + ) + + return delimited_byte_string + return this + + def unicodestring_sql(self, expression: exp.UnicodeString) -> str: + this = self.sql(expression, "this") + escape = expression.args.get("escape") + + if self.dialect.UNICODE_START: + escape_substitute = r"\\\1" + left_quote, right_quote = ( + self.dialect.UNICODE_START, + self.dialect.UNICODE_END, + ) + else: + escape_substitute = r"\\u\1" + left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END + + if escape: + escape_pattern = re.compile(rf"{escape.name}(\d+)") + escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else "" + else: + escape_pattern = ESCAPED_UNICODE_RE + escape_sql = "" + + if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE): + this = escape_pattern.sub( + self.UNICODE_SUBSTITUTE or escape_substitute, this + ) + + return f"{left_quote}{this}{right_quote}{escape_sql}" + + def rawstring_sql(self, expression: exp.RawString) -> str: + string = expression.this + if "\\" in self.dialect.tokenizer_class.STRING_ESCAPES: + string = string.replace("\\", "\\\\") + + string = self.escape_str(string, escape_backslash=False) + return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}" + + def datatypeparam_sql(self, expression: 
exp.DataTypeParam) -> str: + this = self.sql(expression, "this") + specifier = self.sql(expression, "expression") + specifier = ( + f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else "" + ) + return f"{this}{specifier}" + + def datatype_sql(self, expression: exp.DataType) -> str: + nested = "" + values = "" + interior = self.expressions(expression, flat=True) + + type_value = expression.this + if type_value in self.UNSUPPORTED_TYPES: + self.unsupported( + f"Data type {type_value.value} is not supported when targeting {self.dialect.__class__.__name__}" + ) + + if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"): + type_sql = self.sql(expression, "kind") + else: + type_sql = ( + self.TYPE_MAPPING.get(type_value, type_value.value) + if isinstance(type_value, exp.DataType.Type) + else type_value + ) + + if interior: + if expression.args.get("nested"): + nested = ( + f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" + ) + if expression.args.get("values") is not None: + delimiters = ( + ("[", "]") + if type_value == exp.DataType.Type.ARRAY + else ("(", ")") + ) + values = self.expressions(expression, key="values", flat=True) + values = f"{delimiters[0]}{values}{delimiters[1]}" + elif type_value == exp.DataType.Type.INTERVAL: + nested = f" {interior}" + else: + nested = f"({interior})" + + type_sql = f"{type_sql}{nested}{values}" + if self.TZ_TO_WITH_TIME_ZONE and type_value in ( + exp.DataType.Type.TIMETZ, + exp.DataType.Type.TIMESTAMPTZ, + ): + type_sql = f"{type_sql} WITH TIME ZONE" + + return type_sql + + def directory_sql(self, expression: exp.Directory) -> str: + local = "LOCAL " if expression.args.get("local") else "" + row_format = self.sql(expression, "row_format") + row_format = f" {row_format}" if row_format else "" + return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" + + def delete_sql(self, expression: exp.Delete) -> str: + this = self.sql(expression, "this") + this = f" FROM {this}" if this else "" + using = self.expressions(expression, key="using") + using = f" USING {using}" if using else "" + cluster = self.sql(expression, "cluster") + cluster = f" {cluster}" if cluster else "" + where = self.sql(expression, "where") + returning = self.sql(expression, "returning") + order = self.sql(expression, "order") + limit = self.sql(expression, "limit") + tables = self.expressions(expression, key="tables") + tables = f" {tables}" if tables else "" + if self.RETURNING_END: + expression_sql = f"{this}{using}{cluster}{where}{returning}{order}{limit}" + else: + expression_sql = f"{returning}{this}{using}{cluster}{where}{order}{limit}" + return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}") + + def drop_sql(self, expression: exp.Drop) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + expressions = f" ({expressions})" if expressions else "" + kind = expression.args["kind"] + kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind + exists_sql = " IF EXISTS " if expression.args.get("exists") else " " + concurrently_sql = ( + " CONCURRENTLY" if expression.args.get("concurrently") else "" + ) + on_cluster = self.sql(expression, "cluster") + on_cluster = f" {on_cluster}" if on_cluster else "" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + materialized = " MATERIALIZED" if expression.args.get("materialized") else "" + cascade = " CASCADE" if expression.args.get("cascade") else "" + constraints = " CONSTRAINTS" 
+
+    def set_operation(self, expression: exp.SetOperation) -> str:
+        op_type = type(expression)
+        op_name = op_type.key.upper()
+
+        distinct = expression.args.get("distinct")
+        if (
+            distinct is False
+            and op_type in (exp.Except, exp.Intersect)
+            and not self.EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE
+        ):
+            self.unsupported(f"{op_name} ALL is not supported")
+
+        default_distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[op_type]
+
+        if distinct is None:
+            distinct = default_distinct
+            if distinct is None:
+                self.unsupported(f"{op_name} requires DISTINCT or ALL to be specified")
+
+        if distinct is default_distinct:
+            distinct_or_all = ""
+        else:
+            distinct_or_all = " DISTINCT" if distinct else " ALL"
+
+        side_kind = " ".join(filter(None, [expression.side, expression.kind]))
+        side_kind = f"{side_kind} " if side_kind else ""
+
+        by_name = " BY NAME" if expression.args.get("by_name") else ""
+        on = self.expressions(expression, key="on", flat=True)
+        on = f" ON ({on})" if on else ""
+
+        return f"{side_kind}{op_name}{distinct_or_all}{by_name}{on}"
+
+    def set_operations(self, expression: exp.SetOperation) -> str:
+        if not self.SET_OP_MODIFIERS:
+            limit = expression.args.get("limit")
+            order = expression.args.get("order")
+
+            if limit or order:
+                select = self._move_ctes_to_top_level(
+                    exp.subquery(expression, "_l_0", copy=False).select("*", copy=False)
+                )
+
+                if limit:
+                    select = select.limit(limit.pop(), copy=False)
+                if order:
+                    select = select.order_by(order.pop(), copy=False)
+                return self.sql(select)
+
+        sqls: t.List[str] = []
+        stack: t.List[t.Union[str, exp.Expression]] = [expression]
+
+        while stack:
+            node = stack.pop()
+
+            if isinstance(node, exp.SetOperation):
+                stack.append(node.expression)
+                stack.append(
+                    self.maybe_comment(
+                        self.set_operation(node), comments=node.comments, separated=True
+                    )
+                )
+                stack.append(node.this)
+            else:
+                sqls.append(self.sql(node))
+
+        this = self.sep().join(sqls)
+        this = self.query_modifiers(expression, this)
+        return self.prepend_ctes(expression, this)
+
+    def fetch_sql(self, expression: exp.Fetch) -> str:
+        direction = expression.args.get("direction")
+        direction = f" {direction}" if direction else ""
+        count = self.sql(expression, "count")
+        count = f" {count}" if count else ""
+        limit_options = self.sql(expression, "limit_options")
+        limit_options = f"{limit_options}" if limit_options else " ROWS ONLY"
+        return f"{self.seg('FETCH')}{direction}{count}{limit_options}"
+
+    def limitoptions_sql(self, expression: exp.LimitOptions) -> str:
+        percent = " PERCENT" if expression.args.get("percent") else ""
+        rows = " ROWS" if expression.args.get("rows") else ""
+        with_ties = " WITH TIES" if expression.args.get("with_ties") else ""
+        if not with_ties and rows:
+            with_ties = " ONLY"
+        return f"{percent}{rows}{with_ties}"
+
+    def filter_sql(self, expression: exp.Filter) -> str:
+        if self.AGGREGATE_FILTER_SUPPORTED:
+            this = self.sql(expression, "this")
+            where = self.sql(expression, "expression").strip()
+            return f"{this} FILTER({where})"
+
+        agg = expression.this
+        agg_arg = agg.this
+        cond = expression.expression.this
+        agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
+        return self.sql(agg)
+
+    def hint_sql(self, expression: exp.Hint) -> str:
+        if not self.QUERY_HINTS:
+            self.unsupported("Hints are not supported")
+            return ""
+
+        return (
+            f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
+        )
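`filter_sql` illustrates the transpilation trick used when a target dialect lacks the standard `FILTER (WHERE ...)` clause: the predicate is folded into the aggregate's argument as an `exp.If`, which then renders as a `CASE WHEN` expression. A hedged sketch, assuming upstream-compatible `transpile` and that MySQL's generator sets `AGGREGATE_FILTER_SUPPORTED = False`:

```python
import bigframes_vendored.sqlglot as sg

# The FILTER predicate is folded into the aggregate argument for MySQL.
sql = "SELECT SUM(x) FILTER(WHERE x > 0) FROM t"
print(sg.transpile(sql, read="postgres", write="mysql")[0])
# Expected (roughly): SELECT SUM(CASE WHEN x > 0 THEN x END) FROM t
```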
supported") + return "" + + return ( + f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */" + ) + + def indexparameters_sql(self, expression: exp.IndexParameters) -> str: + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" + partition_by = self.expressions(expression, key="partition_by", flat=True) + partition_by = f" PARTITION BY {partition_by}" if partition_by else "" + where = self.sql(expression, "where") + include = self.expressions(expression, key="include", flat=True) + if include: + include = f" INCLUDE ({include})" + with_storage = self.expressions(expression, key="with_storage", flat=True) + with_storage = f" WITH ({with_storage})" if with_storage else "" + tablespace = self.sql(expression, "tablespace") + tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else "" + on = self.sql(expression, "on") + on = f" ON {on}" if on else "" + + return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}{on}" + + def index_sql(self, expression: exp.Index) -> str: + unique = "UNIQUE " if expression.args.get("unique") else "" + primary = "PRIMARY " if expression.args.get("primary") else "" + amp = "AMP " if expression.args.get("amp") else "" + name = self.sql(expression, "this") + name = f"{name} " if name else "" + table = self.sql(expression, "table") + table = f"{self.INDEX_ON} {table}" if table else "" + + index = "INDEX " if not table else "" + + params = self.sql(expression, "params") + return f"{unique}{primary}{amp}{index}{name}{table}{params}" + + def identifier_sql(self, expression: exp.Identifier) -> str: + text = expression.name + lower = text.lower() + text = lower if self.normalize and not expression.quoted else text + text = text.replace(self._identifier_end, self._escaped_identifier_end) + if ( + expression.quoted + or self.dialect.can_quote(expression, self.identify) + or lower in self.RESERVED_KEYWORDS + or ( + not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit() + ) + ): + text = f"{self._identifier_start}{text}{self._identifier_end}" + return text + + def hex_sql(self, expression: exp.Hex) -> str: + text = self.func(self.HEX_FUNC, self.sql(expression, "this")) + if self.dialect.HEX_LOWERCASE: + text = self.func("LOWER", text) + + return text + + def lowerhex_sql(self, expression: exp.LowerHex) -> str: + text = self.func(self.HEX_FUNC, self.sql(expression, "this")) + if not self.dialect.HEX_LOWERCASE: + text = self.func("LOWER", text) + return text + + def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: + input_format = self.sql(expression, "input_format") + input_format = f"INPUTFORMAT {input_format}" if input_format else "" + output_format = self.sql(expression, "output_format") + output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" + return self.sep().join((input_format, output_format)) + + def national_sql(self, expression: exp.National, prefix: str = "N") -> str: + string = self.sql(exp.Literal.string(expression.name)) + return f"{prefix}{string}" + + def partition_sql(self, expression: exp.Partition) -> str: + partition_keyword = ( + "SUBPARTITION" if expression.args.get("subpartition") else "PARTITION" + ) + return f"{partition_keyword}({self.expressions(expression, flat=True)})" + + def properties_sql(self, expression: exp.Properties) -> str: + root_properties = [] + with_properties = [] + + for p in 
+
+    def hex_sql(self, expression: exp.Hex) -> str:
+        text = self.func(self.HEX_FUNC, self.sql(expression, "this"))
+        if self.dialect.HEX_LOWERCASE:
+            text = self.func("LOWER", text)
+
+        return text
+
+    def lowerhex_sql(self, expression: exp.LowerHex) -> str:
+        text = self.func(self.HEX_FUNC, self.sql(expression, "this"))
+        if not self.dialect.HEX_LOWERCASE:
+            text = self.func("LOWER", text)
+        return text
+
+    def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
+        input_format = self.sql(expression, "input_format")
+        input_format = f"INPUTFORMAT {input_format}" if input_format else ""
+        output_format = self.sql(expression, "output_format")
+        output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
+        return self.sep().join((input_format, output_format))
+
+    def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
+        string = self.sql(exp.Literal.string(expression.name))
+        return f"{prefix}{string}"
+
+    def partition_sql(self, expression: exp.Partition) -> str:
+        partition_keyword = (
+            "SUBPARTITION" if expression.args.get("subpartition") else "PARTITION"
+        )
+        return f"{partition_keyword}({self.expressions(expression, flat=True)})"
+
+    def properties_sql(self, expression: exp.Properties) -> str:
+        root_properties = []
+        with_properties = []
+
+        for p in expression.expressions:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.POST_WITH:
+                with_properties.append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA:
+                root_properties.append(p)
+
+        root_props_ast = exp.Properties(expressions=root_properties)
+        root_props_ast.parent = expression.parent
+
+        with_props_ast = exp.Properties(expressions=with_properties)
+        with_props_ast.parent = expression.parent
+
+        root_props = self.root_properties(root_props_ast)
+        with_props = self.with_properties(with_props_ast)
+
+        if root_props and with_props and not self.pretty:
+            with_props = " " + with_props
+
+        return root_props + with_props
+
+    def root_properties(self, properties: exp.Properties) -> str:
+        if properties.expressions:
+            return self.expressions(properties, indent=False, sep=" ")
+        return ""
+
+    def properties(
+        self,
+        properties: exp.Properties,
+        prefix: str = "",
+        sep: str = ", ",
+        suffix: str = "",
+        wrapped: bool = True,
+    ) -> str:
+        if properties.expressions:
+            expressions = self.expressions(properties, sep=sep, indent=False)
+            if expressions:
+                expressions = self.wrap(expressions) if wrapped else expressions
+                return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}"
+        return ""
+
+    def with_properties(self, properties: exp.Properties) -> str:
+        return self.properties(
+            properties, prefix=self.seg(self.WITH_PROPERTIES_PREFIX, sep="")
+        )
+
+    def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
+        properties_locs = defaultdict(list)
+        for p in properties.expressions:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc != exp.Properties.Location.UNSUPPORTED:
+                properties_locs[p_loc].append(p)
+            else:
+                self.unsupported(f"Unsupported property {p.key}")
+
+        return properties_locs
+
+    def property_name(self, expression: exp.Property, string_key: bool = False) -> str:
+        if isinstance(expression.this, exp.Dot):
+            return self.sql(expression, "this")
+        return f"'{expression.name}'" if string_key else expression.name
+
+    def property_sql(self, expression: exp.Property) -> str:
+        property_cls = expression.__class__
+        if property_cls == exp.Property:
+            return f"{self.property_name(expression)}={self.sql(expression, 'value')}"
+
+        property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
+        if not property_name:
+            self.unsupported(f"Unsupported property {expression.key}")
+
+        return f"{property_name}={self.sql(expression, 'this')}"
+
+    def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
+        if self.SUPPORTS_CREATE_TABLE_LIKE:
+            options = " ".join(
+                f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions
+            )
+            options = f" {options}" if options else ""
+
+            like = f"LIKE {self.sql(expression, 'this')}{options}"
+            if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(
+                expression.parent, exp.Schema
+            ):
+                like = f"({like})"
+
+            return like
+
+        if expression.expressions:
+            self.unsupported("Transpilation of LIKE property options is unsupported")
+
+        select = exp.select("*").from_(expression.this).limit(0)
+        return f"AS {self.sql(select)}"
+
+    def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
+        no = "NO " if expression.args.get("no") else ""
+        protection = " PROTECTION" if expression.args.get("protection") else ""
+        return f"{no}FALLBACK{protection}"
+
+    def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
+        no = "NO " if expression.args.get("no") else ""
+        local = expression.args.get("local")
+        local = f"{local} " if local else ""
+        dual = "DUAL " if expression.args.get("dual") else ""
+        before = "BEFORE " if expression.args.get("before") else ""
+        after = "AFTER " if expression.args.get("after") else ""
+        return f"{no}{local}{dual}{before}{after}JOURNAL"
"DUAL " if expression.args.get("dual") else "" + before = "BEFORE " if expression.args.get("before") else "" + after = "AFTER " if expression.args.get("after") else "" + return f"{no}{local}{dual}{before}{after}JOURNAL" + + def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str: + freespace = self.sql(expression, "this") + percent = " PERCENT" if expression.args.get("percent") else "" + return f"FREESPACE={freespace}{percent}" + + def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str: + if expression.args.get("default"): + property = "DEFAULT" + elif expression.args.get("on"): + property = "ON" + else: + property = "OFF" + return f"CHECKSUM={property}" + + def mergeblockratioproperty_sql( + self, expression: exp.MergeBlockRatioProperty + ) -> str: + if expression.args.get("no"): + return "NO MERGEBLOCKRATIO" + if expression.args.get("default"): + return "DEFAULT MERGEBLOCKRATIO" + + percent = " PERCENT" if expression.args.get("percent") else "" + return f"MERGEBLOCKRATIO={self.sql(expression, 'this')}{percent}" + + def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str: + default = expression.args.get("default") + minimum = expression.args.get("minimum") + maximum = expression.args.get("maximum") + if default or minimum or maximum: + if default: + prop = "DEFAULT" + elif minimum: + prop = "MINIMUM" + else: + prop = "MAXIMUM" + return f"{prop} DATABLOCKSIZE" + units = expression.args.get("units") + units = f" {units}" if units else "" + return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" + + def blockcompressionproperty_sql( + self, expression: exp.BlockCompressionProperty + ) -> str: + autotemp = expression.args.get("autotemp") + always = expression.args.get("always") + default = expression.args.get("default") + manual = expression.args.get("manual") + never = expression.args.get("never") + + if autotemp is not None: + prop = f"AUTOTEMP({self.expressions(autotemp)})" + elif always: + prop = "ALWAYS" + elif default: + prop = "DEFAULT" + elif manual: + prop = "MANUAL" + elif never: + prop = "NEVER" + return f"BLOCKCOMPRESSION={prop}" + + def isolatedloadingproperty_sql( + self, expression: exp.IsolatedLoadingProperty + ) -> str: + no = expression.args.get("no") + no = " NO" if no else "" + concurrent = expression.args.get("concurrent") + concurrent = " CONCURRENT" if concurrent else "" + target = self.sql(expression, "target") + target = f" {target}" if target else "" + return f"WITH{no}{concurrent} ISOLATED LOADING{target}" + + def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str: + if isinstance(expression.this, list): + return f"IN ({self.expressions(expression, key='this', flat=True)})" + if expression.this: + modulus = self.sql(expression, "this") + remainder = self.sql(expression, "expression") + return f"WITH (MODULUS {modulus}, REMAINDER {remainder})" + + from_expressions = self.expressions( + expression, key="from_expressions", flat=True + ) + to_expressions = self.expressions(expression, key="to_expressions", flat=True) + return f"FROM ({from_expressions}) TO ({to_expressions})" + + def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str: + this = self.sql(expression, "this") + + for_values_or_default = expression.expression + if isinstance(for_values_or_default, exp.PartitionBoundSpec): + for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}" + else: + for_values_or_default = " DEFAULT" + + return f"PARTITION OF {this}{for_values_or_default}" 
+
+    def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
+        kind = expression.args.get("kind")
+        this = f" {self.sql(expression, 'this')}" if expression.this else ""
+        for_or_in = expression.args.get("for_or_in")
+        for_or_in = f" {for_or_in}" if for_or_in else ""
+        lock_type = expression.args.get("lock_type")
+        override = " OVERRIDE" if expression.args.get("override") else ""
+        return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}"
+
+    def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
+        data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
+        statistics = expression.args.get("statistics")
+        statistics_sql = ""
+        if statistics is not None:
+            statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
+        return f"{data_sql}{statistics_sql}"
+
+    def withsystemversioningproperty_sql(
+        self, expression: exp.WithSystemVersioningProperty
+    ) -> str:
+        this = self.sql(expression, "this")
+        this = f"HISTORY_TABLE={this}" if this else ""
+        data_consistency: t.Optional[str] = self.sql(expression, "data_consistency")
+        data_consistency = (
+            f"DATA_CONSISTENCY_CHECK={data_consistency}" if data_consistency else None
+        )
+        retention_period: t.Optional[str] = self.sql(expression, "retention_period")
+        retention_period = (
+            f"HISTORY_RETENTION_PERIOD={retention_period}" if retention_period else None
+        )
+
+        if this:
+            on_sql = self.func("ON", this, data_consistency, retention_period)
+        else:
+            on_sql = "ON" if expression.args.get("on") else "OFF"
+
+        sql = f"SYSTEM_VERSIONING={on_sql}"
+
+        return f"WITH({sql})" if expression.args.get("with_") else sql
+
+    def insert_sql(self, expression: exp.Insert) -> str:
+        hint = self.sql(expression, "hint")
+        overwrite = expression.args.get("overwrite")
+
+        if isinstance(expression.this, exp.Directory):
+            this = " OVERWRITE" if overwrite else " INTO"
+        else:
+            this = self.INSERT_OVERWRITE if overwrite else " INTO"
+
+        stored = self.sql(expression, "stored")
+        stored = f" {stored}" if stored else ""
+        alternative = expression.args.get("alternative")
+        alternative = f" OR {alternative}" if alternative else ""
+        ignore = " IGNORE" if expression.args.get("ignore") else ""
+        is_function = expression.args.get("is_function")
+        if is_function:
+            this = f"{this} FUNCTION"
+        this = f"{this} {self.sql(expression, 'this')}"
+
+        exists = " IF EXISTS" if expression.args.get("exists") else ""
+        where = self.sql(expression, "where")
+        where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
+        expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
+        on_conflict = self.sql(expression, "conflict")
+        on_conflict = f" {on_conflict}" if on_conflict else ""
+        by_name = " BY NAME" if expression.args.get("by_name") else ""
+        default_values = "DEFAULT VALUES" if expression.args.get("default") else ""
+        returning = self.sql(expression, "returning")
+
+        if self.RETURNING_END:
+            expression_sql = f"{expression_sql}{on_conflict}{default_values}{returning}"
+        else:
+            expression_sql = f"{returning}{expression_sql}{on_conflict}"
+
+        partition_by = self.sql(expression, "partition")
+        partition_by = f" {partition_by}" if partition_by else ""
+        settings = self.sql(expression, "settings")
+        settings = f" {settings}" if settings else ""
+
+        source = self.sql(expression, "source")
+        source = f"TABLE {source}" if source else ""
+
+        sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{partition_by}{settings}{where}{expression_sql}{source}"
+        return self.prepend_ctes(expression, sql)
+
+    def introducer_sql(self, expression: exp.Introducer) -> str:
+        return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+
+    def kill_sql(self, expression: exp.Kill) -> str:
+        kind = self.sql(expression, "kind")
+        kind = f" {kind}" if kind else ""
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        return f"KILL{kind}{this}"
+
+    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
+        return expression.name
+
+    def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
+        return expression.name
+
+    def onconflict_sql(self, expression: exp.OnConflict) -> str:
+        conflict = (
+            "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
+        )
+
+        constraint = self.sql(expression, "constraint")
+        constraint = f" ON CONSTRAINT {constraint}" if constraint else ""
+
+        conflict_keys = self.expressions(expression, key="conflict_keys", flat=True)
+        conflict_keys = f"({conflict_keys}) " if conflict_keys else " "
+        action = self.sql(expression, "action")
+
+        expressions = self.expressions(expression, flat=True)
+        if expressions:
+            set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
+            expressions = f" {set_keyword}{expressions}"
+
+        where = self.sql(expression, "where")
+        return f"{conflict}{constraint}{conflict_keys}{action}{expressions}{where}"
+
+    def returning_sql(self, expression: exp.Returning) -> str:
+        return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
+
+    def rowformatdelimitedproperty_sql(
+        self, expression: exp.RowFormatDelimitedProperty
+    ) -> str:
+        fields = self.sql(expression, "fields")
+        fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
+        escaped = self.sql(expression, "escaped")
+        escaped = f" ESCAPED BY {escaped}" if escaped else ""
+        items = self.sql(expression, "collection_items")
+        items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
+        keys = self.sql(expression, "map_keys")
+        keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
+        lines = self.sql(expression, "lines")
+        lines = f" LINES TERMINATED BY {lines}" if lines else ""
+        null = self.sql(expression, "null")
+        null = f" NULL DEFINED AS {null}" if null else ""
+        return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
+
+    def withtablehint_sql(self, expression: exp.WithTableHint) -> str:
+        return f"WITH ({self.expressions(expression, flat=True)})"
+
+    def indextablehint_sql(self, expression: exp.IndexTableHint) -> str:
+        this = f"{self.sql(expression, 'this')} INDEX"
+        target = self.sql(expression, "target")
+        target = f" FOR {target}" if target else ""
+        return f"{this}{target} ({self.expressions(expression, flat=True)})"
+
+    def historicaldata_sql(self, expression: exp.HistoricalData) -> str:
+        this = self.sql(expression, "this")
+        kind = self.sql(expression, "kind")
+        expr = self.sql(expression, "expression")
+        return f"{this} ({kind} => {expr})"
+
+    def table_parts(self, expression: exp.Table) -> str:
+        return ".".join(
+            self.sql(part)
+            for part in (
+                expression.args.get("catalog"),
+                expression.args.get("db"),
+                expression.args.get("this"),
+            )
+            if part is not None
+        )
alias = f"{sep}{alias}" if alias else "" + + sample = self.sql(expression, "sample") + if self.dialect.ALIAS_POST_TABLESAMPLE: + sample_pre_alias = sample + sample_post_alias = "" + else: + sample_pre_alias = "" + sample_post_alias = sample + + hints = self.expressions(expression, key="hints", sep=" ") + hints = f" {hints}" if hints and self.TABLE_HINTS else "" + pivots = self.expressions(expression, key="pivots", sep="", flat=True) + joins = self.indent( + self.expressions(expression, key="joins", sep="", flat=True), + skip_first=True, + ) + laterals = self.expressions(expression, key="laterals", sep="") + + file_format = self.sql(expression, "format") + if file_format: + pattern = self.sql(expression, "pattern") + pattern = f", PATTERN => {pattern}" if pattern else "" + file_format = f" (FILE_FORMAT => {file_format}{pattern})" + + ordinality = expression.args.get("ordinality") or "" + if ordinality: + ordinality = f" WITH ORDINALITY{alias}" + alias = "" + + when = self.sql(expression, "when") + if when: + table = f"{table} {when}" + + changes = self.sql(expression, "changes") + changes = f" {changes}" if changes else "" + + rows_from = self.expressions(expression, key="rows_from") + if rows_from: + table = f"ROWS FROM {self.wrap(rows_from)}" + + indexed = expression.args.get("indexed") + if indexed is not None: + indexed = f" INDEXED BY {self.sql(indexed)}" if indexed else " NOT INDEXED" + else: + indexed = "" + + return f"{only}{table}{changes}{partition}{version}{file_format}{sample_pre_alias}{alias}{indexed}{hints}{pivots}{sample_post_alias}{joins}{laterals}{ordinality}" + + def tablefromrows_sql(self, expression: exp.TableFromRows) -> str: + table = self.func("TABLE", expression.this) + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + sample = self.sql(expression, "sample") + pivots = self.expressions(expression, key="pivots", sep="", flat=True) + joins = self.indent( + self.expressions(expression, key="joins", sep="", flat=True), + skip_first=True, + ) + return f"{table}{alias}{pivots}{sample}{joins}" + + def tablesample_sql( + self, + expression: exp.TableSample, + tablesample_keyword: t.Optional[str] = None, + ) -> str: + method = self.sql(expression, "method") + method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else "" + numerator = self.sql(expression, "bucket_numerator") + denominator = self.sql(expression, "bucket_denominator") + field = self.sql(expression, "bucket_field") + field = f" ON {field}" if field else "" + bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else "" + seed = self.sql(expression, "seed") + seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else "" + + size = self.sql(expression, "size") + if size and self.TABLESAMPLE_SIZE_IS_ROWS: + size = f"{size} ROWS" + + percent = self.sql(expression, "percent") + if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: + percent = f"{percent} PERCENT" + + expr = f"{bucket}{percent}{size}" + if self.TABLESAMPLE_REQUIRES_PARENS: + expr = f"({expr})" + + return ( + f" {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}" + ) + + def pivot_sql(self, expression: exp.Pivot) -> str: + expressions = self.expressions(expression, flat=True) + direction = "UNPIVOT" if expression.unpivot else "PIVOT" + + group = self.sql(expression, "group") + + if expression.this: + this = self.sql(expression, "this") + if not expressions: + sql = f"UNPIVOT {this}" + else: + on = f"{self.seg('ON')} {expressions}" + into = self.sql(expression, 
"into") + into = f"{self.seg('INTO')} {into}" if into else "" + using = self.expressions(expression, key="using", flat=True) + using = f"{self.seg('USING')} {using}" if using else "" + sql = f"{direction} {this}{on}{into}{using}{group}" + return self.prepend_ctes(expression, sql) + + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + + fields = self.expressions( + expression, + "fields", + sep=" ", + dynamic=True, + new_line=True, + skip_first=True, + skip_last=True, + ) + + include_nulls = expression.args.get("include_nulls") + if include_nulls is not None: + nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS " + else: + nulls = "" + + default_on_null = self.sql(expression, "default_on_null") + default_on_null = ( + f" DEFAULT ON NULL ({default_on_null})" if default_on_null else "" + ) + sql = f"{self.seg(direction)}{nulls}({expressions} FOR {fields}{default_on_null}{group}){alias}" + return self.prepend_ctes(expression, sql) + + def version_sql(self, expression: exp.Version) -> str: + this = f"FOR {expression.name}" + kind = expression.text("kind") + expr = self.sql(expression, "expression") + return f"{this} {kind} {expr}" + + def tuple_sql(self, expression: exp.Tuple) -> str: + return f"({self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)})" + + def update_sql(self, expression: exp.Update) -> str: + this = self.sql(expression, "this") + set_sql = self.expressions(expression, flat=True) + from_sql = self.sql(expression, "from_") + where_sql = self.sql(expression, "where") + returning = self.sql(expression, "returning") + order = self.sql(expression, "order") + limit = self.sql(expression, "limit") + if self.RETURNING_END: + expression_sql = f"{from_sql}{where_sql}{returning}" + else: + expression_sql = f"{returning}{from_sql}{where_sql}" + options = self.expressions(expression, key="options") + options = f" OPTION({options})" if options else "" + sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}{options}" + return self.prepend_ctes(expression, sql) + + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: + values_as_table = values_as_table and self.VALUES_AS_TABLE + + # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example + if values_as_table or not expression.find_ancestor(exp.From, exp.Join): + args = self.expressions(expression) + alias = self.sql(expression, "alias") + values = f"VALUES{self.seg('')}{args}" + values = ( + f"({values})" + if self.WRAP_DERIVED_VALUES + and (alias or isinstance(expression.parent, (exp.From, exp.Table))) + else values + ) + values = self.query_modifiers(expression, values) + return f"{values} AS {alias}" if alias else values + + # Converts `VALUES...` expression into a series of select unions. + alias_node = expression.args.get("alias") + column_names = alias_node and alias_node.columns + + selects: t.List[exp.Query] = [] + + for i, tup in enumerate(expression.expressions): + row = tup.expressions + + if i == 0 and column_names: + row = [ + exp.alias_(value, column_name) + for value, column_name in zip(row, column_names) + ] + + selects.append(exp.Select(expressions=row)) + + if self.pretty: + # This may result in poor performance for large-cardinality `VALUES` tables, due to + # the deep nesting of the resulting exp.Unions. If this is a problem, either increase + # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`. 
+
+    def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
+        values_as_table = values_as_table and self.VALUES_AS_TABLE
+
+        # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
+        if values_as_table or not expression.find_ancestor(exp.From, exp.Join):
+            args = self.expressions(expression)
+            alias = self.sql(expression, "alias")
+            values = f"VALUES{self.seg('')}{args}"
+            values = (
+                f"({values})"
+                if self.WRAP_DERIVED_VALUES
+                and (alias or isinstance(expression.parent, (exp.From, exp.Table)))
+                else values
+            )
+            values = self.query_modifiers(expression, values)
+            return f"{values} AS {alias}" if alias else values
+
+        # Converts `VALUES...` expression into a series of select unions.
+        alias_node = expression.args.get("alias")
+        column_names = alias_node and alias_node.columns
+
+        selects: t.List[exp.Query] = []
+
+        for i, tup in enumerate(expression.expressions):
+            row = tup.expressions
+
+            if i == 0 and column_names:
+                row = [
+                    exp.alias_(value, column_name)
+                    for value, column_name in zip(row, column_names)
+                ]
+
+            selects.append(exp.Select(expressions=row))
+
+        if self.pretty:
+            # This may result in poor performance for large-cardinality `VALUES` tables, due to
+            # the deep nesting of the resulting exp.Unions. If this is a problem, either increase
+            # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
+            query = reduce(
+                lambda x, y: exp.union(x, y, distinct=False, copy=False), selects
+            )
+            return self.subquery_sql(
+                query.subquery(alias_node and alias_node.this, copy=False)
+            )
+
+        alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else ""
+        unions = " UNION ALL ".join(self.sql(select) for select in selects)
+        return f"({unions}){alias}"
+
+    def var_sql(self, expression: exp.Var) -> str:
+        return self.sql(expression, "this")
+
+    @unsupported_args("expressions")
+    def into_sql(self, expression: exp.Into) -> str:
+        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+        unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
+        return (
+            f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
+        )
+
+    def from_sql(self, expression: exp.From) -> str:
+        return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
+
+    def groupingsets_sql(self, expression: exp.GroupingSets) -> str:
+        grouping_sets = self.expressions(expression, indent=False)
+        return f"GROUPING SETS {self.wrap(grouping_sets)}"
+
+    def rollup_sql(self, expression: exp.Rollup) -> str:
+        expressions = self.expressions(expression, indent=False)
+        return f"ROLLUP {self.wrap(expressions)}" if expressions else "WITH ROLLUP"
+
+    def cube_sql(self, expression: exp.Cube) -> str:
+        expressions = self.expressions(expression, indent=False)
+        return f"CUBE {self.wrap(expressions)}" if expressions else "WITH CUBE"
+
+    def group_sql(self, expression: exp.Group) -> str:
+        group_by_all = expression.args.get("all")
+        if group_by_all is True:
+            modifier = " ALL"
+        elif group_by_all is False:
+            modifier = " DISTINCT"
+        else:
+            modifier = ""
+
+        group_by = self.op_expressions(f"GROUP BY{modifier}", expression)
+
+        grouping_sets = self.expressions(expression, key="grouping_sets")
+        cube = self.expressions(expression, key="cube")
+        rollup = self.expressions(expression, key="rollup")
+
+        groupings = csv(
+            self.seg(grouping_sets) if grouping_sets else "",
+            self.seg(cube) if cube else "",
+            self.seg(rollup) if rollup else "",
+            self.seg("WITH TOTALS") if expression.args.get("totals") else "",
+            sep=self.GROUPINGS_SEP,
+        )
+
+        if (
+            expression.expressions
+            and groupings
+            and groupings.strip() not in ("WITH CUBE", "WITH ROLLUP")
+        ):
+            group_by = f"{group_by}{self.GROUPINGS_SEP}"
+
+        return f"{group_by}{groupings}"
+
+    def having_sql(self, expression: exp.Having) -> str:
+        this = self.indent(self.sql(expression, "this"))
+        return f"{self.seg('HAVING')}{self.sep()}{this}"
+
+    def connect_sql(self, expression: exp.Connect) -> str:
+        start = self.sql(expression, "start")
+        start = self.seg(f"START WITH {start}") if start else ""
+        nocycle = " NOCYCLE" if expression.args.get("nocycle") else ""
+        connect = self.sql(expression, "connect")
+        connect = self.seg(f"CONNECT BY{nocycle} {connect}")
+        return start + connect
+
+    def prior_sql(self, expression: exp.Prior) -> str:
+        return f"PRIOR {self.sql(expression, 'this')}"
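`values_sql` shows the fallback for dialects that cannot use `VALUES` as a derived table: each row becomes a `SELECT`, chained with `UNION ALL`, with the column aliases applied to the first row. A hedged sketch, assuming Hive's generator sets `VALUES_AS_TABLE = False` as in upstream sqlglot:

```python
import bigframes_vendored.sqlglot as sg

# The derived VALUES table is rewritten as a chain of SELECT ... UNION ALL.
sql = "SELECT c FROM (VALUES (1), (2)) AS t(c)"
print(sg.transpile(sql, read="duckdb", write="hive")[0])
# Expected (roughly): SELECT c FROM (SELECT 1 AS c UNION ALL SELECT 2) AS t
```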
+
+    def join_sql(self, expression: exp.Join) -> str:
+        if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ("SEMI", "ANTI"):
+            side = None
+        else:
+            side = expression.side
+
+        op_sql = " ".join(
+            op
+            for op in (
+                expression.method,
+                "GLOBAL" if expression.args.get("global_") else None,
+                side,
+                expression.kind,
+                expression.hint if self.JOIN_HINTS else None,
+            )
+            if op
+        )
+        match_cond = self.sql(expression, "match_condition")
+        match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else ""
+        on_sql = self.sql(expression, "on")
+        using = expression.args.get("using")
+
+        if not on_sql and using:
+            on_sql = csv(*(self.sql(column) for column in using))
+
+        this = expression.this
+        this_sql = self.sql(this)
+
+        exprs = self.expressions(expression)
+        if exprs:
+            this_sql = f"{this_sql},{self.seg(exprs)}"
+
+        if on_sql:
+            on_sql = self.indent(on_sql, skip_first=True)
+            space = self.seg(" " * self.pad) if self.pretty else " "
+            if using:
+                on_sql = f"{space}USING ({on_sql})"
+            else:
+                on_sql = f"{space}ON {on_sql}"
+        elif not op_sql:
+            if (
+                isinstance(this, exp.Lateral)
+                and this.args.get("cross_apply") is not None
+            ):
+                return f" {this_sql}"
+
+            return f", {this_sql}"
+
+        if op_sql != "STRAIGHT_JOIN":
+            op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
+
+        pivots = self.expressions(expression, key="pivots", sep="", flat=True)
+        return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}{pivots}"
+
+    def lambda_sql(
+        self, expression: exp.Lambda, arrow_sep: str = "->", wrap: bool = True
+    ) -> str:
+        args = self.expressions(expression, flat=True)
+        args = f"({args})" if wrap and len(args.split(",")) > 1 else args
+        return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
+
+    def lateral_op(self, expression: exp.Lateral) -> str:
+        cross_apply = expression.args.get("cross_apply")
+
+        # https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/
+        if cross_apply is True:
+            op = "INNER JOIN "
+        elif cross_apply is False:
+            op = "LEFT JOIN "
+        else:
+            op = ""
+
+        return f"{op}LATERAL"
+
+    def lateral_sql(self, expression: exp.Lateral) -> str:
+        this = self.sql(expression, "this")
+
+        if expression.args.get("view"):
+            alias = expression.args["alias"]
+            columns = self.expressions(alias, key="columns", flat=True)
+            table = f" {alias.name}" if alias.name else ""
+            columns = f" AS {columns}" if columns else ""
+            op_sql = self.seg(
+                f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
+            )
+            return f"{op_sql}{self.sep()}{this}{table}{columns}"
+
+        alias = self.sql(expression, "alias")
+        alias = f" AS {alias}" if alias else ""
+
+        ordinality = expression.args.get("ordinality") or ""
+        if ordinality:
+            ordinality = f" WITH ORDINALITY{alias}"
+            alias = ""
+
+        return f"{self.lateral_op(expression)} {this}{alias}{ordinality}"
+
+    def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
+        this = self.sql(expression, "this")
+
+        args = [
+            self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e
+            for e in (expression.args.get(k) for k in ("offset", "expression"))
+            if e
+        ]
+
+        args_sql = ", ".join(self.sql(e) for e in args)
+        args_sql = (
+            f"({args_sql})" if top and any(not e.is_number for e in args) else args_sql
+        )
+        expressions = self.expressions(expression, flat=True)
+        limit_options = self.sql(expression, "limit_options")
+        expressions = f" BY {expressions}" if expressions else ""
+
+        return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{limit_options}{expressions}"
+
+    def offset_sql(self, expression: exp.Offset) -> str:
+        this = self.sql(expression, "this")
+        value = expression.expression
+        value = (
+            self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value
+        )
+        expressions = self.expressions(expression, flat=True)
+        expressions = f" BY {expressions}" if expressions else ""
+        return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}"
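`limit_sql(top=True)` renders the same `exp.Limit` node as a `TOP` clause after `SELECT` instead of a trailing `LIMIT`; `select_sql` (further below) is what decides to call it that way via `LIMIT_IS_TOP`. A quick illustration, assuming the vendored `transpile` mirrors upstream sqlglot:

```python
import bigframes_vendored.sqlglot as sg

# T-SQL has no LIMIT clause, so the node is emitted as TOP after SELECT.
print(sg.transpile("SELECT x FROM t LIMIT 5", read="duckdb", write="tsql")[0])
# Expected: SELECT TOP 5 x FROM t
```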
= f"{kind} " if kind else "" + this = self.sql(expression, "this") + expressions = self.expressions(expression) + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + global_ = "GLOBAL " if expression.args.get("global_") else "" + return f"{global_}{kind}{this}{expressions}{collate}" + + def set_sql(self, expression: exp.Set) -> str: + expressions = f" {self.expressions(expression, flat=True)}" + tag = " TAG" if expression.args.get("tag") else "" + return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}" + + def queryband_sql(self, expression: exp.QueryBand) -> str: + this = self.sql(expression, "this") + update = " UPDATE" if expression.args.get("update") else "" + scope = self.sql(expression, "scope") + scope = f" FOR {scope}" if scope else "" + + return f"QUERY_BAND = {this}{update}{scope}" + + def pragma_sql(self, expression: exp.Pragma) -> str: + return f"PRAGMA {self.sql(expression, 'this')}" + + def lock_sql(self, expression: exp.Lock) -> str: + if not self.LOCKING_READS_SUPPORTED: + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" + + update = expression.args["update"] + key = expression.args.get("key") + if update: + lock_type = "FOR NO KEY UPDATE" if key else "FOR UPDATE" + else: + lock_type = "FOR KEY SHARE" if key else "FOR SHARE" + expressions = self.expressions(expression, flat=True) + expressions = f" OF {expressions}" if expressions else "" + wait = expression.args.get("wait") + + if wait is not None: + if isinstance(wait, exp.Literal): + wait = f" WAIT {self.sql(wait)}" + else: + wait = " NOWAIT" if wait else " SKIP LOCKED" + + return f"{lock_type}{expressions}{wait or ''}" + + def literal_sql(self, expression: exp.Literal) -> str: + text = expression.this or "" + if expression.is_string: + text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}" + return text + + def escape_str( + self, + text: str, + escape_backslash: bool = True, + delimiter: t.Optional[str] = None, + escaped_delimiter: t.Optional[str] = None, + ) -> str: + if self.dialect.ESCAPED_SEQUENCES: + to_escaped = self.dialect.ESCAPED_SEQUENCES + text = "".join( + to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch + for ch in text + ) + + delimiter = delimiter or self.dialect.QUOTE_END + escaped_delimiter = escaped_delimiter or self._escaped_quote_end + + return self._replace_line_breaks(text).replace(delimiter, escaped_delimiter) + + def loaddata_sql(self, expression: exp.LoadData) -> str: + local = " LOCAL" if expression.args.get("local") else "" + inpath = f" INPATH {self.sql(expression, 'inpath')}" + overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" + this = f" INTO TABLE {self.sql(expression, 'this')}" + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + input_format = self.sql(expression, "input_format") + input_format = f" INPUTFORMAT {input_format}" if input_format else "" + serde = self.sql(expression, "serde") + serde = f" SERDE {serde}" if serde else "" + return ( + f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" + ) + + def null_sql(self, *_) -> str: + return "NULL" + + def boolean_sql(self, expression: exp.Boolean) -> str: + return "TRUE" if expression.this else "FALSE" + + def booland_sql(self, expression: exp.Booland) -> str: + return f"(({self.sql(expression, 'this')}) AND ({self.sql(expression, 'expression')}))" + + def boolor_sql(self, expression: 
+
+    def boolor_sql(self, expression: exp.Boolor) -> str:
+        return f"(({self.sql(expression, 'this')}) OR ({self.sql(expression, 'expression')}))"
+
+    def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
+        this = self.sql(expression, "this")
+        this = f"{this} " if this else this
+        siblings = "SIBLINGS " if expression.args.get("siblings") else ""
+        return self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat)  # type: ignore
+
+    def withfill_sql(self, expression: exp.WithFill) -> str:
+        from_sql = self.sql(expression, "from_")
+        from_sql = f" FROM {from_sql}" if from_sql else ""
+        to_sql = self.sql(expression, "to")
+        to_sql = f" TO {to_sql}" if to_sql else ""
+        step_sql = self.sql(expression, "step")
+        step_sql = f" STEP {step_sql}" if step_sql else ""
+        interpolated_values = [
+            f"{self.sql(e, 'alias')} AS {self.sql(e, 'this')}"
+            if isinstance(e, exp.Alias)
+            else self.sql(e, "this")
+            for e in expression.args.get("interpolate") or []
+        ]
+        interpolate = (
+            f" INTERPOLATE ({', '.join(interpolated_values)})"
+            if interpolated_values
+            else ""
+        )
+        return f"WITH FILL{from_sql}{to_sql}{step_sql}{interpolate}"
+
+    def cluster_sql(self, expression: exp.Cluster) -> str:
+        return self.op_expressions("CLUSTER BY", expression)
+
+    def distribute_sql(self, expression: exp.Distribute) -> str:
+        return self.op_expressions("DISTRIBUTE BY", expression)
+
+    def sort_sql(self, expression: exp.Sort) -> str:
+        return self.op_expressions("SORT BY", expression)
+
+    def ordered_sql(self, expression: exp.Ordered) -> str:
+        desc = expression.args.get("desc")
+        asc = not desc
+
+        nulls_first = expression.args.get("nulls_first")
+        nulls_last = not nulls_first
+        nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large"
+        nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small"
+        nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last"
+
+        this = self.sql(expression, "this")
+
+        sort_order = " DESC" if desc else (" ASC" if desc is False else "")
+        nulls_sort_change = ""
+        if nulls_first and (
+            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
+        ):
+            nulls_sort_change = " NULLS FIRST"
+        elif (
+            nulls_last
+            and ((asc and nulls_are_small) or (desc and nulls_are_large))
+            and not nulls_are_last
+        ):
+            nulls_sort_change = " NULLS LAST"
+
+        # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
+        if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
+            window = expression.find_ancestor(exp.Window, exp.Select)
+            if isinstance(window, exp.Window) and window.args.get("spec"):
+                self.unsupported(
+                    f"'{nulls_sort_change.strip()}' translation not supported in window functions"
+                )
+                nulls_sort_change = ""
+            elif self.NULL_ORDERING_SUPPORTED is False and (
+                (asc and nulls_sort_change == " NULLS LAST")
+                or (desc and nulls_sort_change == " NULLS FIRST")
+            ):
+                # BigQuery does not allow these ordering/nulls combinations when used under
+                # an aggregation func or under a window containing one
+                ancestor = expression.find_ancestor(exp.AggFunc, exp.Window, exp.Select)
+
+                if isinstance(ancestor, exp.Window):
+                    ancestor = ancestor.this
+                if isinstance(ancestor, exp.AggFunc):
+                    self.unsupported(
+                        f"'{nulls_sort_change.strip()}' translation not supported for aggregate functions with {sort_order} sort order"
+                    )
+                    nulls_sort_change = ""
+            elif self.NULL_ORDERING_SUPPORTED is None:
+                if expression.this.is_int:
+                    self.unsupported(
+                        f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
+                    )
+                elif not isinstance(expression.this, exp.Rand):
+                    null_sort_order = (
+                        " DESC" if nulls_sort_change == " NULLS FIRST" else ""
+                    )
+                    this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
+                    nulls_sort_change = ""
+
+        with_fill = self.sql(expression, "with_fill")
+        with_fill = f" {with_fill}" if with_fill else ""
+
+        return f"{this}{sort_order}{nulls_sort_change}{with_fill}"
+
+    def matchrecognizemeasure_sql(self, expression: exp.MatchRecognizeMeasure) -> str:
+        window_frame = self.sql(expression, "window_frame")
+        window_frame = f"{window_frame} " if window_frame else ""
+
+        this = self.sql(expression, "this")
+
+        return f"{window_frame}{this}"
+
+    def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
+        partition = self.partition_by_sql(expression)
+        order = self.sql(expression, "order")
+        measures = self.expressions(expression, key="measures")
+        measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
+        rows = self.sql(expression, "rows")
+        rows = self.seg(rows) if rows else ""
+        after = self.sql(expression, "after")
+        after = self.seg(after) if after else ""
+        pattern = self.sql(expression, "pattern")
+        pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
+        definition_sqls = [
+            f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
+            for definition in expression.args.get("define", [])
+        ]
+        definitions = self.expressions(sqls=definition_sqls)
+        define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
+        body = "".join(
+            (
+                partition,
+                order,
+                measures,
+                rows,
+                after,
+                pattern,
+                define,
+            )
+        )
+        alias = self.sql(expression, "alias")
+        alias = f" {alias}" if alias else ""
+        return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
+
+    def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
+        limit = expression.args.get("limit")
+
+        if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
+            limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
+        elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
+            limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
+
+        return csv(
+            *sqls,
+            *[self.sql(join) for join in expression.args.get("joins") or []],
+            self.sql(expression, "match"),
+            *[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
+            self.sql(expression, "prewhere"),
+            self.sql(expression, "where"),
+            self.sql(expression, "connect"),
+            self.sql(expression, "group"),
+            self.sql(expression, "having"),
+            *[
+                gen(self, expression)
+                for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()
+            ],
+            self.sql(expression, "order"),
+            *self.offset_limit_modifiers(
+                expression, isinstance(limit, exp.Fetch), limit
+            ),
+            *self.after_limit_modifiers(expression),
+            self.options_modifier(expression),
+            self.for_modifiers(expression),
+            sep="",
+        )
+
+    def options_modifier(self, expression: exp.Expression) -> str:
+        options = self.expressions(expression, key="options")
+        return f" {options}" if options else ""
+
+    def for_modifiers(self, expression: exp.Expression) -> str:
+        for_modifiers = self.expressions(expression, key="for_")
+        return f"{self.sep()}FOR XML{self.seg(for_modifiers)}" if for_modifiers else ""
+
+    def queryoption_sql(self, expression: exp.QueryOption) -> str:
+        self.unsupported("Unsupported query option.")
+        return ""
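The `ordered_sql` branch for `NULL_ORDERING_SUPPORTED is None` emulates `NULLS FIRST/LAST` with an extra `CASE` sort key when the dialect has no native clause. A hedged sketch, assuming MySQL is such a dialect in the vendored copy, as in upstream sqlglot:

```python
import bigframes_vendored.sqlglot as sg

# NULLS LAST is simulated by sorting an IS NULL flag before the column itself.
sql = "SELECT x FROM t ORDER BY x NULLS LAST"
print(sg.transpile(sql, read="postgres", write="mysql")[0])
# Expected (roughly):
# SELECT x FROM t ORDER BY CASE WHEN x IS NULL THEN 1 ELSE 0 END, x
```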
+
+    def offset_limit_modifiers(
+        self,
+        expression: exp.Expression,
+        fetch: bool,
+        limit: t.Optional[exp.Fetch | exp.Limit],
+    ) -> t.List[str]:
+        return [
+            self.sql(expression, "offset") if fetch else self.sql(limit),
+            self.sql(limit) if fetch else self.sql(expression, "offset"),
+        ]
+
+    def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
+        locks = self.expressions(expression, key="locks", sep=" ")
+        locks = f" {locks}" if locks else ""
+        return [locks, self.sql(expression, "sample")]
+
+    def select_sql(self, expression: exp.Select) -> str:
+        into = expression.args.get("into")
+        if not self.SUPPORTS_SELECT_INTO and into:
+            into.pop()
+
+        hint = self.sql(expression, "hint")
+        distinct = self.sql(expression, "distinct")
+        distinct = f" {distinct}" if distinct else ""
+        kind = self.sql(expression, "kind")
+
+        limit = expression.args.get("limit")
+        if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP:
+            top = self.limit_sql(limit, top=True)
+            limit.pop()
+        else:
+            top = ""
+
+        expressions = self.expressions(expression)
+
+        if kind:
+            if kind in self.SELECT_KINDS:
+                kind = f" AS {kind}"
+            else:
+                if kind == "STRUCT":
+                    expressions = self.expressions(
+                        sqls=[
+                            self.sql(
+                                exp.Struct(
+                                    expressions=[
+                                        exp.PropertyEQ(
+                                            this=e.args.get("alias"), expression=e.this
+                                        )
+                                        if isinstance(e, exp.Alias)
+                                        else e
+                                        for e in expression.expressions
+                                    ]
+                                )
+                            )
+                        ]
+                    )
+                kind = ""
+
+        operation_modifiers = self.expressions(
+            expression, key="operation_modifiers", sep=" "
+        )
+        operation_modifiers = (
+            f"{self.sep()}{operation_modifiers}" if operation_modifiers else ""
+        )
+
+        # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
+        # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
+        top_distinct = (
+            f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
+        )
+        expressions = f"{self.sep()}{expressions}" if expressions else expressions
+        sql = self.query_modifiers(
+            expression,
+            f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}",
+            self.sql(expression, "into", comment=False),
+            self.sql(expression, "from_", comment=False),
+        )
+
+        # If both the CTE and SELECT clauses have comments, generate the latter earlier
+        if expression.args.get("with_"):
+            sql = self.maybe_comment(sql, expression)
+            expression.pop_comments()
+
+        sql = self.prepend_ctes(expression, sql)
+
+        if not self.SUPPORTS_SELECT_INTO and into:
+            if into.args.get("temporary"):
+                table_kind = " TEMPORARY"
+            elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"):
+                table_kind = " UNLOGGED"
+            else:
+                table_kind = ""
+            sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}"
+
+        return sql
+
+    def schema_sql(self, expression: exp.Schema) -> str:
+        this = self.sql(expression, "this")
+        sql = self.schema_columns_sql(expression)
+        return f"{this} {sql}" if this and sql else this or sql
+
+    def schema_columns_sql(self, expression: exp.Schema) -> str:
+        if expression.expressions:
+            return (
+                f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+            )
+        return ""
+
+    def star_sql(self, expression: exp.Star) -> str:
+        except_ = self.expressions(expression, key="except_", flat=True)
+        except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else ""
+        replace = self.expressions(expression, key="replace", flat=True)
+        replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
+        rename = self.expressions(expression, key="rename", flat=True)
+        rename = f"{self.seg('RENAME')} ({rename})" if rename else ""
+        return f"*{except_}{replace}{rename}"
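The tail of `select_sql` shows how `SELECT ... INTO` is handled for dialects without it: the `into` node is popped before generation and the query is re-wrapped as `CREATE TABLE ... AS`. A hedged sketch, assuming DuckDB's generator sets `SUPPORTS_SELECT_INTO = False` as upstream does:

```python
import bigframes_vendored.sqlglot as sg

# The INTO target is popped and the whole query becomes a CTAS statement.
print(sg.transpile("SELECT * INTO t2 FROM t1", read="tsql", write="duckdb")[0])
# Expected (roughly): CREATE TABLE t2 AS SELECT * FROM t1
```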
"this") + return f"{self.PARAMETER_TOKEN}{this}" + + def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: + this = self.sql(expression, "this") + kind = expression.text("kind") + if kind: + kind = f"{kind}." + return f"@@{kind}{this}" + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return ( + f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" + if expression.this + else "?" + ) + + def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: + alias = self.sql(expression, "alias") + alias = f"{sep}{alias}" if alias else "" + sample = self.sql(expression, "sample") + if self.dialect.ALIAS_POST_TABLESAMPLE and sample: + alias = f"{sample}{alias}" + + # Set to None so it's not generated again by self.query_modifiers() + expression.set("sample", None) + + pivots = self.expressions(expression, key="pivots", sep="", flat=True) + sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots) + return self.prepend_ctes(expression, sql) + + def qualify_sql(self, expression: exp.Qualify) -> str: + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('QUALIFY')}{self.sep()}{this}" + + def unnest_sql(self, expression: exp.Unnest) -> str: + args = self.expressions(expression, flat=True) + + alias = expression.args.get("alias") + offset = expression.args.get("offset") + + if self.UNNEST_WITH_ORDINALITY: + if alias and isinstance(offset, exp.Expression): + alias.append("columns", offset) + + if alias and self.dialect.UNNEST_COLUMN_ONLY: + columns = alias.columns + alias = self.sql(columns[0]) if columns else "" + else: + alias = self.sql(alias) + + alias = f" AS {alias}" if alias else alias + if self.UNNEST_WITH_ORDINALITY: + suffix = f" WITH ORDINALITY{alias}" if offset else alias + else: + if isinstance(offset, exp.Expression): + suffix = f"{alias} WITH OFFSET AS {self.sql(offset)}" + elif offset: + suffix = f"{alias} WITH OFFSET" + else: + suffix = alias + + return f"UNNEST({args}){suffix}" + + def prewhere_sql(self, expression: exp.PreWhere) -> str: + return "" + + def where_sql(self, expression: exp.Where) -> str: + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('WHERE')}{self.sep()}{this}" + + def window_sql(self, expression: exp.Window) -> str: + this = self.sql(expression, "this") + partition = self.partition_by_sql(expression) + order = expression.args.get("order") + order = self.order_sql(order, flat=True) if order else "" + spec = self.sql(expression, "spec") + alias = self.sql(expression, "alias") + over = self.sql(expression, "over") or "OVER" + + this = f"{this} {'AS' if expression.arg_key == 'windows' else over}" + + first = expression.args.get("first") + if first is None: + first = "" + else: + first = "FIRST" if first else "LAST" + + if not partition and not order and not spec and alias: + return f"{this} {alias}" + + args = self.format_args( + *[arg for arg in (alias, first, partition, order, spec) if arg], sep=" " + ) + return f"{this} ({args})" + + def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str: + partition = self.expressions(expression, key="partition_by", flat=True) + return f"PARTITION BY {partition}" if partition else "" + + def windowspec_sql(self, expression: exp.WindowSpec) -> str: + kind = self.sql(expression, "kind") + start = csv( + self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" " + ) + end = ( + csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") + or "CURRENT ROW" + ) + + window_spec 
= f"{kind} BETWEEN {start} AND {end}" + + exclude = self.sql(expression, "exclude") + if exclude: + if self.SUPPORTS_WINDOW_EXCLUDE: + window_spec += f" EXCLUDE {exclude}" + else: + self.unsupported("EXCLUDE clause is not supported in the WINDOW clause") + + return window_spec + + def withingroup_sql(self, expression: exp.WithinGroup) -> str: + this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression")[ + 1: + ] # order has a leading space + return f"{this} WITHIN GROUP ({expression_sql})" + + def between_sql(self, expression: exp.Between) -> str: + this = self.sql(expression, "this") + low = self.sql(expression, "low") + high = self.sql(expression, "high") + symmetric = expression.args.get("symmetric") + + if symmetric and not self.SUPPORTS_BETWEEN_FLAGS: + return ( + f"({this} BETWEEN {low} AND {high} OR {this} BETWEEN {high} AND {low})" + ) + + flag = ( + " SYMMETRIC" + if symmetric + else " ASYMMETRIC" + if symmetric is False and self.SUPPORTS_BETWEEN_FLAGS + else "" # silently drop ASYMMETRIC – semantics identical + ) + return f"{this} BETWEEN{flag} {low} AND {high}" + + def bracket_offset_expressions( + self, expression: exp.Bracket, index_offset: t.Optional[int] = None + ) -> t.List[exp.Expression]: + return apply_index_offset( + expression.this, + expression.expressions, + (index_offset or self.dialect.INDEX_OFFSET) + - expression.args.get("offset", 0), + dialect=self.dialect, + ) + + def bracket_sql(self, expression: exp.Bracket) -> str: + expressions = self.bracket_offset_expressions(expression) + expressions_sql = ", ".join(self.sql(e) for e in expressions) + return f"{self.sql(expression, 'this')}[{expressions_sql}]" + + def all_sql(self, expression: exp.All) -> str: + this = self.sql(expression, "this") + if not isinstance(expression.this, (exp.Tuple, exp.Paren)): + this = self.wrap(this) + return f"ALL {this}" + + def any_sql(self, expression: exp.Any) -> str: + this = self.sql(expression, "this") + if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)): + if isinstance(expression.this, exp.UNWRAPPED_QUERIES): + this = self.wrap(this) + return f"ANY{this}" + return f"ANY {this}" + + def exists_sql(self, expression: exp.Exists) -> str: + return f"EXISTS{self.wrap(expression)}" + + def case_sql(self, expression: exp.Case) -> str: + this = self.sql(expression, "this") + statements = [f"CASE {this}" if this else "CASE"] + + for e in expression.args["ifs"]: + statements.append(f"WHEN {self.sql(e, 'this')}") + statements.append(f"THEN {self.sql(e, 'true')}") + + default = self.sql(expression, "default") + + if default: + statements.append(f"ELSE {default}") + + statements.append("END") + + if self.pretty and self.too_wide(statements): + return self.indent("\n".join(statements), skip_first=True, skip_last=True) + + return " ".join(statements) + + def constraint_sql(self, expression: exp.Constraint) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"CONSTRAINT {this} {expressions}" + + def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str: + order = expression.args.get("order") + order = f" OVER ({self.order_sql(order, flat=True)})" if order else "" + return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}" + + def extract_sql(self, expression: exp.Extract) -> str: + from bigframes_vendored.sqlglot.dialects.dialect import map_date_part + + this = ( + map_date_part(expression.this, self.dialect) + if self.NORMALIZE_EXTRACT_DATE_PARTS + else expression.this + ) + 
+
+    def extract_sql(self, expression: exp.Extract) -> str:
+        from bigframes_vendored.sqlglot.dialects.dialect import map_date_part
+
+        this = (
+            map_date_part(expression.this, self.dialect)
+            if self.NORMALIZE_EXTRACT_DATE_PARTS
+            else expression.this
+        )
+        this_sql = self.sql(this) if self.EXTRACT_ALLOWS_QUOTES else this.name
+        expression_sql = self.sql(expression, "expression")
+
+        return f"EXTRACT({this_sql} FROM {expression_sql})"
+
+    def trim_sql(self, expression: exp.Trim) -> str:
+        trim_type = self.sql(expression, "position")
+
+        if trim_type == "LEADING":
+            func_name = "LTRIM"
+        elif trim_type == "TRAILING":
+            func_name = "RTRIM"
+        else:
+            func_name = "TRIM"
+
+        return self.func(func_name, expression.this, expression.expression)
+
+    def convert_concat_args(
+        self, expression: exp.Concat | exp.ConcatWs
+    ) -> t.List[exp.Expression]:
+        args = expression.expressions
+        if isinstance(expression, exp.ConcatWs):
+            args = args[1:]  # Skip the delimiter
+
+        if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
+            args = [exp.cast(e, exp.DataType.Type.TEXT) for e in args]
+
+        if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
+
+            def _wrap_with_coalesce(e: exp.Expression) -> exp.Expression:
+                if not e.type:
+                    from bigframes_vendored.sqlglot.optimizer.annotate_types import (
+                        annotate_types,
+                    )
+
+                    e = annotate_types(e, dialect=self.dialect)
+
+                if e.is_string or e.is_type(exp.DataType.Type.ARRAY):
+                    return e
+
+                return exp.func("coalesce", e, exp.Literal.string(""))
+
+            args = [_wrap_with_coalesce(e) for e in args]
+
+        return args
+
+    def concat_sql(self, expression: exp.Concat) -> str:
+        if self.dialect.CONCAT_COALESCE and not expression.args.get("coalesce"):
+            # Dialect's CONCAT function coalesces NULLs to empty strings, but the expression does not.
+            # Transpile to double pipe operators, which typically returns NULL if any args are NULL
+            # instead of coalescing them to empty string.
+            from bigframes_vendored.sqlglot.dialects.dialect import concat_to_dpipe_sql
+
+            return concat_to_dpipe_sql(self, expression)
+
+        expressions = self.convert_concat_args(expression)
+
+        # Some dialects don't allow a single-argument CONCAT call
+        if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1:
+            return self.sql(expressions[0])
+
+        return self.func("CONCAT", *expressions)
+
+    def concatws_sql(self, expression: exp.ConcatWs) -> str:
+        return self.func(
+            "CONCAT_WS",
+            seq_get(expression.expressions, 0),
+            *self.convert_concat_args(expression),
+        )
+
+    def check_sql(self, expression: exp.Check) -> str:
+        this = self.sql(expression, key="this")
+        return f"CHECK ({this})"
+
+    def foreignkey_sql(self, expression: exp.ForeignKey) -> str:
+        expressions = self.expressions(expression, flat=True)
+        expressions = f" ({expressions})" if expressions else ""
+        reference = self.sql(expression, "reference")
+        reference = f" {reference}" if reference else ""
+        delete = self.sql(expression, "delete")
+        delete = f" ON DELETE {delete}" if delete else ""
+        update = self.sql(expression, "update")
+        update = f" ON UPDATE {update}" if update else ""
+        options = self.expressions(expression, key="options", flat=True, sep=" ")
+        options = f" {options}" if options else ""
+        return f"FOREIGN KEY{expressions}{reference}{delete}{update}{options}"
+
+    def primarykey_sql(self, expression: exp.PrimaryKey) -> str:
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        expressions = self.expressions(expression, flat=True)
+        include = self.sql(expression, "include")
+        options = self.expressions(expression, key="options", flat=True, sep=" ")
+        options = f" {options}" if options else ""
+        return f"PRIMARY KEY{this} ({expressions}){include}{options}"
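The base `trim_sql` maps positional `TRIM` onto `LTRIM`/`RTRIM` function calls rather than emitting the `LEADING`/`TRAILING` keywords. A minimal sketch against the default generator, assuming it matches the code above:

```python
import bigframes_vendored.sqlglot as sg

# TRIM(LEADING ...) becomes an LTRIM call in the default generator.
print(sg.parse_one("SELECT TRIM(LEADING 'x' FROM y)").sql())
# Expected (roughly): SELECT LTRIM(y, 'x')
```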
exp.Case(ifs=[expression], default=expression.args.get("false")) + ) + + def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: + if self.MATCH_AGAINST_TABLE_PREFIX: + expressions = [] + for expr in expression.expressions: + if isinstance(expr, exp.Table): + expressions.append(f"TABLE {self.sql(expr)}") + else: + expressions.append(expr) + else: + expressions = expression.expressions + + modifier = expression.args.get("modifier") + modifier = f" {modifier}" if modifier else "" + return f"{self.func('MATCH', *expressions)} AGAINST({self.sql(expression, 'this')}{modifier})" + + def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: + return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}" + + def jsonpath_sql(self, expression: exp.JSONPath) -> str: + path = self.expressions(expression, sep="", flat=True).lstrip(".") + + if expression.args.get("escape"): + path = self.escape_str(path) + + if self.QUOTE_JSON_PATH: + path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" + + return path + + def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str: + if isinstance(expression, exp.JSONPathPart): + transform = self.TRANSFORMS.get(expression.__class__) + if not callable(transform): + self.unsupported( + f"Unsupported JSONPathPart type {expression.__class__.__name__}" + ) + return "" + + return transform(self, expression) + + if isinstance(expression, int): + return str(expression) + + if ( + self._quote_json_path_key_using_brackets + and self.JSON_PATH_SINGLE_QUOTE_ESCAPE + ): + escaped = expression.replace("'", "\\'") + escaped = f"\\'{escaped}\\'" + else: + escaped = expression.replace('"', '\\"') + escaped = f'"{escaped}"' + + return escaped + + def formatjson_sql(self, expression: exp.FormatJson) -> str: + return f"{self.sql(expression, 'this')} FORMAT JSON" + + def formatphrase_sql(self, expression: exp.FormatPhrase) -> str: + # Output the Teradata column FORMAT override.
+ # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Data-Types-and-Literals/Data-Type-Formats-and-Format-Phrases/FORMAT + this = self.sql(expression, "this") + fmt = self.sql(expression, "format") + return f"{this} (FORMAT {fmt})" + + def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str: + null_handling = expression.args.get("null_handling") + null_handling = f" {null_handling}" if null_handling else "" + + unique_keys = expression.args.get("unique_keys") + if unique_keys is not None: + unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS" + else: + unique_keys = "" + + return_type = self.sql(expression, "return_type") + return_type = f" RETURNING {return_type}" if return_type else "" + encoding = self.sql(expression, "encoding") + encoding = f" ENCODING {encoding}" if encoding else "" + + return self.func( + "JSON_OBJECT" + if isinstance(expression, exp.JSONObject) + else "JSON_OBJECTAGG", + *expression.expressions, + suffix=f"{null_handling}{unique_keys}{return_type}{encoding})", + ) + + def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str: + return self.jsonobject_sql(expression) + + def jsonarray_sql(self, expression: exp.JSONArray) -> str: + null_handling = expression.args.get("null_handling") + null_handling = f" {null_handling}" if null_handling else "" + return_type = self.sql(expression, "return_type") + return_type = f" RETURNING {return_type}" if return_type else "" + strict = " STRICT" if expression.args.get("strict") else "" + return self.func( + "JSON_ARRAY", + *expression.expressions, + suffix=f"{null_handling}{return_type}{strict})", + ) + + def jsonarrayagg_sql(self, expression: exp.JSONArrayAgg) -> str: + this = self.sql(expression, "this") + order = self.sql(expression, "order") + null_handling = expression.args.get("null_handling") + null_handling = f" {null_handling}" if null_handling else "" + return_type = self.sql(expression, "return_type") + return_type = f" RETURNING {return_type}" if return_type else "" + strict = " STRICT" if expression.args.get("strict") else "" + return self.func( + "JSON_ARRAYAGG", + this, + suffix=f"{order}{null_handling}{return_type}{strict})", + ) + + def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str: + path = self.sql(expression, "path") + path = f" PATH {path}" if path else "" + nested_schema = self.sql(expression, "nested_schema") + + if nested_schema: + return f"NESTED{path} {nested_schema}" + + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + + ordinality = " FOR ORDINALITY" if expression.args.get("ordinality") else "" + return f"{this}{kind}{path}{ordinality}" + + def jsonschema_sql(self, expression: exp.JSONSchema) -> str: + return self.func("COLUMNS", *expression.expressions) + + def jsontable_sql(self, expression: exp.JSONTable) -> str: + this = self.sql(expression, "this") + path = self.sql(expression, "path") + path = f", {path}" if path else "" + error_handling = expression.args.get("error_handling") + error_handling = f" {error_handling}" if error_handling else "" + empty_handling = expression.args.get("empty_handling") + empty_handling = f" {empty_handling}" if empty_handling else "" + schema = self.sql(expression, "schema") + return self.func( + "JSON_TABLE", + this, + suffix=f"{path}{error_handling}{empty_handling} {schema})", + ) + + def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, 
"kind") + path = self.sql(expression, "path") + path = f" {path}" if path else "" + as_json = " AS JSON" if expression.args.get("as_json") else "" + return f"{this} {kind}{path}{as_json}" + + def openjson_sql(self, expression: exp.OpenJSON) -> str: + this = self.sql(expression, "this") + path = self.sql(expression, "path") + path = f", {path}" if path else "" + expressions = self.expressions(expression) + with_ = ( + f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}" + if expressions + else "" + ) + return f"OPENJSON({this}{path}){with_}" + + def in_sql(self, expression: exp.In) -> str: + query = expression.args.get("query") + unnest = expression.args.get("unnest") + field = expression.args.get("field") + is_global = " GLOBAL" if expression.args.get("is_global") else "" + + if query: + in_sql = self.sql(query) + elif unnest: + in_sql = self.in_unnest_op(unnest) + elif field: + in_sql = self.sql(field) + else: + in_sql = f"({self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)})" + + return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}" + + def in_unnest_op(self, unnest: exp.Unnest) -> str: + return f"(SELECT {self.sql(unnest)})" + + def interval_sql(self, expression: exp.Interval) -> str: + unit_expression = expression.args.get("unit") + unit = self.sql(unit_expression) if unit_expression else "" + if not self.INTERVAL_ALLOWS_PLURAL_FORM: + unit = self.TIME_PART_SINGULARS.get(unit, unit) + unit = f" {unit}" if unit else "" + + if self.SINGLE_STRING_INTERVAL: + this = expression.this.name if expression.this else "" + if this: + if unit_expression and isinstance(unit_expression, exp.IntervalSpan): + return f"INTERVAL '{this}'{unit}" + return f"INTERVAL '{this}{unit}'" + return f"INTERVAL{unit}" + + this = self.sql(expression, "this") + if this: + unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES) + this = f" {this}" if unwrapped else f" ({this})" + + return f"INTERVAL{this}{unit}" + + def return_sql(self, expression: exp.Return) -> str: + return f"RETURN {self.sql(expression, 'this')}" + + def reference_sql(self, expression: exp.Reference) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + expressions = f"({expressions})" if expressions else "" + options = self.expressions(expression, key="options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"REFERENCES {this}{expressions}{options}" + + def anonymous_sql(self, expression: exp.Anonymous) -> str: + # We don't normalize qualified functions such as a.b.foo(), because they can be case-sensitive + parent = expression.parent + is_qualified = isinstance(parent, exp.Dot) and expression is parent.expression + return self.func( + self.sql(expression, "this"), + *expression.expressions, + normalize=not is_qualified, + ) + + def paren_sql(self, expression: exp.Paren) -> str: + sql = self.seg(self.indent(self.sql(expression, "this")), sep="") + return f"({sql}{self.seg(')', sep='')}" + + def neg_sql(self, expression: exp.Neg) -> str: + # This makes sure we don't convert "- - 5" to "--5", which is a comment + this_sql = self.sql(expression, "this") + sep = " " if this_sql[0] == "-" else "" + return f"-{sep}{this_sql}" + + def not_sql(self, expression: exp.Not) -> str: + return f"NOT {self.sql(expression, 'this')}" + + def alias_sql(self, expression: exp.Alias) -> str: + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return 
f"{self.sql(expression, 'this')}{alias}" + + def pivotalias_sql(self, expression: exp.PivotAlias) -> str: + alias = expression.args["alias"] + + parent = expression.parent + pivot = parent and parent.parent + + if isinstance(pivot, exp.Pivot) and pivot.unpivot: + identifier_alias = isinstance(alias, exp.Identifier) + literal_alias = isinstance(alias, exp.Literal) + + if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS: + alias.replace(exp.Literal.string(alias.output_name)) + elif ( + not identifier_alias + and literal_alias + and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS + ): + alias.replace(exp.to_identifier(alias.output_name)) + + return self.alias_sql(expression) + + def aliases_sql(self, expression: exp.Aliases) -> str: + return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" + + def atindex_sql(self, expression: exp.AtTimeZone) -> str: + this = self.sql(expression, "this") + index = self.sql(expression, "expression") + return f"{this} AT {index}" + + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: + this = self.sql(expression, "this") + zone = self.sql(expression, "zone") + return f"{this} AT TIME ZONE {zone}" + + def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str: + this = self.sql(expression, "this") + zone = self.sql(expression, "zone") + return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'" + + def add_sql(self, expression: exp.Add) -> str: + return self.binary(expression, "+") + + def and_sql( + self, + expression: exp.And, + stack: t.Optional[t.List[str | exp.Expression]] = None, + ) -> str: + return self.connector_sql(expression, "AND", stack) + + def or_sql( + self, expression: exp.Or, stack: t.Optional[t.List[str | exp.Expression]] = None + ) -> str: + return self.connector_sql(expression, "OR", stack) + + def xor_sql( + self, + expression: exp.Xor, + stack: t.Optional[t.List[str | exp.Expression]] = None, + ) -> str: + return self.connector_sql(expression, "XOR", stack) + + def connector_sql( + self, + expression: exp.Connector, + op: str, + stack: t.Optional[t.List[str | exp.Expression]] = None, + ) -> str: + if stack is not None: + if expression.expressions: + stack.append(self.expressions(expression, sep=f" {op} ")) + else: + stack.append(expression.right) + if expression.comments and self.comments: + for comment in expression.comments: + if comment: + op += f" /*{self.sanitize_comment(comment)}*/" + stack.extend((op, expression.left)) + return op + + stack = [expression] + sqls: t.List[str] = [] + ops = set() + + while stack: + node = stack.pop() + if isinstance(node, exp.Connector): + ops.add(getattr(self, f"{node.key}_sql")(node, stack)) + else: + sql = self.sql(node) + if sqls and sqls[-1] in ops: + sqls[-1] += f" {sql}" + else: + sqls.append(sql) + + sep = "\n" if self.pretty and self.too_wide(sqls) else " " + return sep.join(sqls) + + def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: + return self.binary(expression, "&") + + def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str: + return self.binary(expression, "<<") + + def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str: + return f"~{self.sql(expression, 'this')}" + + def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str: + return self.binary(expression, "|") + + def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str: + return self.binary(expression, ">>") + + def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: + return self.binary(expression, "^") + + def cast_sql( + self, 
expression: exp.Cast, safe_prefix: t.Optional[str] = None + ) -> str: + format_sql = self.sql(expression, "format") + format_sql = f" FORMAT {format_sql}" if format_sql else "" + to_sql = self.sql(expression, "to") + to_sql = f" {to_sql}" if to_sql else "" + action = self.sql(expression, "action") + action = f" {action}" if action else "" + default = self.sql(expression, "default") + default = f" DEFAULT {default} ON CONVERSION ERROR" if default else "" + return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{default}{format_sql}{action})" + + # Base implementation that excludes safe, zone, and target_type metadata args + def strtotime_sql(self, expression: exp.StrToTime) -> str: + return self.func("STR_TO_TIME", expression.this, expression.args.get("format")) + + def currentdate_sql(self, expression: exp.CurrentDate) -> str: + zone = self.sql(expression, "this") + return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" + + def collate_sql(self, expression: exp.Collate) -> str: + if self.COLLATE_IS_FUNC: + return self.function_fallback_sql(expression) + return self.binary(expression, "COLLATE") + + def command_sql(self, expression: exp.Command) -> str: + return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}" + + def comment_sql(self, expression: exp.Comment) -> str: + this = self.sql(expression, "this") + kind = expression.args["kind"] + materialized = " MATERIALIZED" if expression.args.get("materialized") else "" + exists_sql = " IF EXISTS " if expression.args.get("exists") else " " + expression_sql = self.sql(expression, "expression") + return f"COMMENT{exists_sql}ON{materialized} {kind} {this} IS {expression_sql}" + + def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str: + this = self.sql(expression, "this") + delete = " DELETE" if expression.args.get("delete") else "" + recompress = self.sql(expression, "recompress") + recompress = f" RECOMPRESS {recompress}" if recompress else "" + to_disk = self.sql(expression, "to_disk") + to_disk = f" TO DISK {to_disk}" if to_disk else "" + to_volume = self.sql(expression, "to_volume") + to_volume = f" TO VOLUME {to_volume}" if to_volume else "" + return f"{this}{delete}{recompress}{to_disk}{to_volume}" + + def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str: + where = self.sql(expression, "where") + group = self.sql(expression, "group") + aggregates = self.expressions(expression, key="aggregates") + aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else "" + + if not (where or group or aggregates) and len(expression.expressions) == 1: + return f"TTL {self.expressions(expression, flat=True)}" + + return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}" + + def transaction_sql(self, expression: exp.Transaction) -> str: + modes = self.expressions(expression, key="modes") + modes = f" {modes}" if modes else "" + return f"BEGIN{modes}" + + def commit_sql(self, expression: exp.Commit) -> str: + chain = expression.args.get("chain") + if chain is not None: + chain = " AND CHAIN" if chain else " AND NO CHAIN" + + return f"COMMIT{chain or ''}" + + def rollback_sql(self, expression: exp.Rollback) -> str: + savepoint = expression.args.get("savepoint") + savepoint = f" TO {savepoint}" if savepoint else "" + return f"ROLLBACK{savepoint}" + + def altercolumn_sql(self, expression: exp.AlterColumn) -> str: + this = self.sql(expression, "this") + + dtype = self.sql(expression, "dtype") + if dtype: + collate = self.sql(expression, "collate") + 
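# The COLLATE and USING riders below are optional and only rendered when present. +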
collate = f" COLLATE {collate}" if collate else "" + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + alter_set_type = self.ALTER_SET_TYPE + " " if self.ALTER_SET_TYPE else "" + return f"ALTER COLUMN {this} {alter_set_type}{dtype}{collate}{using}" + + default = self.sql(expression, "default") + if default: + return f"ALTER COLUMN {this} SET DEFAULT {default}" + + comment = self.sql(expression, "comment") + if comment: + return f"ALTER COLUMN {this} COMMENT {comment}" + + visible = expression.args.get("visible") + if visible: + return f"ALTER COLUMN {this} SET {visible}" + + allow_null = expression.args.get("allow_null") + drop = expression.args.get("drop") + + if not drop and not allow_null: + self.unsupported("Unsupported ALTER COLUMN syntax") + + if allow_null is not None: + keyword = "DROP" if drop else "SET" + return f"ALTER COLUMN {this} {keyword} NOT NULL" + + return f"ALTER COLUMN {this} DROP DEFAULT" + + def alterindex_sql(self, expression: exp.AlterIndex) -> str: + this = self.sql(expression, "this") + + visible = expression.args.get("visible") + visible_sql = "VISIBLE" if visible else "INVISIBLE" + + return f"ALTER INDEX {this} {visible_sql}" + + def alterdiststyle_sql(self, expression: exp.AlterDistStyle) -> str: + this = self.sql(expression, "this") + if not isinstance(expression.this, exp.Var): + this = f"KEY DISTKEY {this}" + return f"ALTER DISTSTYLE {this}" + + def altersortkey_sql(self, expression: exp.AlterSortKey) -> str: + compound = " COMPOUND" if expression.args.get("compound") else "" + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + expressions = f"({expressions})" if expressions else "" + return f"ALTER{compound} SORTKEY {this or expressions}" + + def alterrename_sql( + self, expression: exp.AlterRename, include_to: bool = True + ) -> str: + if not self.RENAME_TABLE_WITH_DB: + # Remove db from tables + expression = expression.transform( + lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n + ).assert_is(exp.AlterRename) + this = self.sql(expression, "this") + to_kw = " TO" if include_to else "" + return f"RENAME{to_kw} {this}" + + def renamecolumn_sql(self, expression: exp.RenameColumn) -> str: + exists = " IF EXISTS" if expression.args.get("exists") else "" + old_column = self.sql(expression, "this") + new_column = self.sql(expression, "to") + return f"RENAME COLUMN{exists} {old_column} TO {new_column}" + + def alterset_sql(self, expression: exp.AlterSet) -> str: + exprs = self.expressions(expression, flat=True) + if self.ALTER_SET_WRAPPED: + exprs = f"({exprs})" + + return f"SET {exprs}" + + def alter_sql(self, expression: exp.Alter) -> str: + actions = expression.args["actions"] + + if not self.dialect.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and isinstance( + actions[0], exp.ColumnDef + ): + actions_sql = self.expressions(expression, key="actions", flat=True) + actions_sql = f"ADD {actions_sql}" + else: + actions_list = [] + for action in actions: + if isinstance(action, (exp.ColumnDef, exp.Schema)): + action_sql = self.add_column_sql(action) + else: + action_sql = self.sql(action) + if isinstance(action, exp.Query): + action_sql = f"AS {action_sql}" + + actions_list.append(action_sql) + + actions_sql = self.format_args(*actions_list).lstrip("\n") + + exists = " IF EXISTS" if expression.args.get("exists") else "" + on_cluster = self.sql(expression, "cluster") + on_cluster = f" {on_cluster}" if on_cluster else "" + only = " ONLY" if expression.args.get("only") else "" + 
options = self.expressions(expression, key="options") + options = f", {options}" if options else "" + kind = self.sql(expression, "kind") + not_valid = " NOT VALID" if expression.args.get("not_valid") else "" + check = " WITH CHECK" if expression.args.get("check") else "" + cascade = ( + " CASCADE" + if expression.args.get("cascade") + and self.dialect.ALTER_TABLE_SUPPORTS_CASCADE + else "" + ) + this = self.sql(expression, "this") + this = f" {this}" if this else "" + + return f"ALTER {kind}{exists}{only}{this}{on_cluster}{check}{self.sep()}{actions_sql}{not_valid}{options}{cascade}" + + def altersession_sql(self, expression: exp.AlterSession) -> str: + items_sql = self.expressions(expression, flat=True) + keyword = "UNSET" if expression.args.get("unset") else "SET" + return f"{keyword} {items_sql}" + + def add_column_sql(self, expression: exp.Expression) -> str: + sql = self.sql(expression) + if isinstance(expression, exp.Schema): + column_text = " COLUMNS" + elif ( + isinstance(expression, exp.ColumnDef) + and self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD + ): + column_text = " COLUMN" + else: + column_text = "" + + return f"ADD{column_text} {sql}" + + def droppartition_sql(self, expression: exp.DropPartition) -> str: + expressions = self.expressions(expression) + exists = " IF EXISTS " if expression.args.get("exists") else " " + return f"DROP{exists}{expressions}" + + def addconstraint_sql(self, expression: exp.AddConstraint) -> str: + return f"ADD {self.expressions(expression, indent=False)}" + + def addpartition_sql(self, expression: exp.AddPartition) -> str: + exists = "IF NOT EXISTS " if expression.args.get("exists") else "" + location = self.sql(expression, "location") + location = f" {location}" if location else "" + return f"ADD {exists}{self.sql(expression.this)}{location}" + + def distinct_sql(self, expression: exp.Distinct) -> str: + this = self.expressions(expression, flat=True) + + if not self.MULTI_ARG_DISTINCT and len(expression.expressions) > 1: + case = exp.case() + for arg in expression.expressions: + case = case.when(arg.is_(exp.null()), exp.null()) + this = self.sql(case.else_(f"({this})")) + + this = f" {this}" if this else "" + + on = self.sql(expression, "on") + on = f" ON {on}" if on else "" + return f"DISTINCT{this}{on}" + + def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: + return self._embed_ignore_nulls(expression, "IGNORE NULLS") + + def respectnulls_sql(self, expression: exp.RespectNulls) -> str: + return self._embed_ignore_nulls(expression, "RESPECT NULLS") + + def havingmax_sql(self, expression: exp.HavingMax) -> str: + this_sql = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + kind = "MAX" if expression.args.get("max") else "MIN" + return f"{this_sql} HAVING {kind} {expression_sql}" + + def intdiv_sql(self, expression: exp.IntDiv) -> str: + return self.sql( + exp.Cast( + this=exp.Div(this=expression.this, expression=expression.expression), + to=exp.DataType(this=exp.DataType.Type.INT), + ) + ) + + def dpipe_sql(self, expression: exp.DPipe) -> str: + if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): + return self.func( + "CONCAT", + *(exp.cast(e, exp.DataType.Type.TEXT) for e in expression.flatten()), + ) + return self.binary(expression, "||") + + def div_sql(self, expression: exp.Div) -> str: + l, r = expression.left, expression.right + + if not self.dialect.SAFE_DIVISION and expression.args.get("safe"): + r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0))) + + if 
self.dialect.TYPED_DIVISION and not expression.args.get("typed"): + if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type( + *exp.DataType.REAL_TYPES + ): + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE)) + + elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"): + if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type( + *exp.DataType.INTEGER_TYPES + ): + return self.sql( + exp.cast( + l / r, + to=exp.DataType.Type.BIGINT, + ) + ) + + return self.binary(expression, "/") + + def safedivide_sql(self, expression: exp.SafeDivide) -> str: + n = exp._wrap(expression.this, exp.Binary) + d = exp._wrap(expression.expression, exp.Binary) + return self.sql(exp.If(this=d.neq(0), true=n / d, false=exp.Null())) + + def overlaps_sql(self, expression: exp.Overlaps) -> str: + return self.binary(expression, "OVERLAPS") + + def distance_sql(self, expression: exp.Distance) -> str: + return self.binary(expression, "<->") + + def dot_sql(self, expression: exp.Dot) -> str: + return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" + + def eq_sql(self, expression: exp.EQ) -> str: + return self.binary(expression, "=") + + def propertyeq_sql(self, expression: exp.PropertyEQ) -> str: + return self.binary(expression, ":=") + + def escape_sql(self, expression: exp.Escape) -> str: + return self.binary(expression, "ESCAPE") + + def glob_sql(self, expression: exp.Glob) -> str: + return self.binary(expression, "GLOB") + + def gt_sql(self, expression: exp.GT) -> str: + return self.binary(expression, ">") + + def gte_sql(self, expression: exp.GTE) -> str: + return self.binary(expression, ">=") + + def is_sql(self, expression: exp.Is) -> str: + if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean): + return self.sql( + expression.this + if expression.expression.this + else exp.not_(expression.this) + ) + return self.binary(expression, "IS") + + def _like_sql(self, expression: exp.Like | exp.ILike) -> str: + this = expression.this + rhs = expression.expression + + if isinstance(expression, exp.Like): + exp_class: t.Type[exp.Like | exp.ILike] = exp.Like + op = "LIKE" + else: + exp_class = exp.ILike + op = "ILIKE" + + if isinstance(rhs, (exp.All, exp.Any)) and not self.SUPPORTS_LIKE_QUANTIFIERS: + exprs = rhs.this.unnest() + + if isinstance(exprs, exp.Tuple): + exprs = exprs.expressions + + connective = exp.or_ if isinstance(rhs, exp.Any) else exp.and_ + + like_expr: exp.Expression = exp_class(this=this, expression=exprs[0]) + for expr in exprs[1:]: + like_expr = connective(like_expr, exp_class(this=this, expression=expr)) + + parent = expression.parent + if not isinstance(parent, type(like_expr)) and isinstance( + parent, exp.Condition + ): + like_expr = exp.paren(like_expr, copy=False) + + return self.sql(like_expr) + + return self.binary(expression, op) + + def like_sql(self, expression: exp.Like) -> str: + return self._like_sql(expression) + + def ilike_sql(self, expression: exp.ILike) -> str: + return self._like_sql(expression) + + def match_sql(self, expression: exp.Match) -> str: + return self.binary(expression, "MATCH") + + def similarto_sql(self, expression: exp.SimilarTo) -> str: + return self.binary(expression, "SIMILAR TO") + + def lt_sql(self, expression: exp.LT) -> str: + return self.binary(expression, "<") + + def lte_sql(self, expression: exp.LTE) -> str: + return self.binary(expression, "<=") + + def mod_sql(self, expression: exp.Mod) -> str: + return self.binary(expression, "%") + + def mul_sql(self, expression: exp.Mul) -> str: 
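+ # Simple arithmetic operators delegate to binary(), which walks the operator chain with an explicit stack instead of recursing.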
+ return self.binary(expression, "*") + + def neq_sql(self, expression: exp.NEQ) -> str: + return self.binary(expression, "<>") + + def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str: + return self.binary(expression, "IS NOT DISTINCT FROM") + + def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: + return self.binary(expression, "IS DISTINCT FROM") + + def sub_sql(self, expression: exp.Sub) -> str: + return self.binary(expression, "-") + + def trycast_sql(self, expression: exp.TryCast) -> str: + return self.cast_sql(expression, safe_prefix="TRY_") + + def jsoncast_sql(self, expression: exp.JSONCast) -> str: + return self.cast_sql(expression) + + def try_sql(self, expression: exp.Try) -> str: + if not self.TRY_SUPPORTED: + self.unsupported("Unsupported TRY function") + return self.sql(expression, "this") + + return self.func("TRY", expression.this) + + def log_sql(self, expression: exp.Log) -> str: + this = expression.this + expr = expression.expression + + if self.dialect.LOG_BASE_FIRST is False: + this, expr = expr, this + elif self.dialect.LOG_BASE_FIRST is None and expr: + if this.name in ("2", "10"): + return self.func(f"LOG{this.name}", expr) + + self.unsupported(f"Unsupported logarithm with base {self.sql(this)}") + + return self.func("LOG", this, expr) + + def use_sql(self, expression: exp.Use) -> str: + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + this = self.sql(expression, "this") or self.expressions(expression, flat=True) + this = f" {this}" if this else "" + return f"USE{kind}{this}" + + def binary(self, expression: exp.Binary, op: str) -> str: + sqls: t.List[str] = [] + stack: t.List[t.Union[str, exp.Expression]] = [expression] + binary_type = type(expression) + + while stack: + node = stack.pop() + + if type(node) is binary_type: + op_func = node.args.get("operator") + if op_func: + op = f"OPERATOR({self.sql(op_func)})" + + stack.append(node.right) + stack.append(f" {self.maybe_comment(op, comments=node.comments)} ") + stack.append(node.left) + else: + sqls.append(self.sql(node)) + + return "".join(sqls) + + def ceil_floor(self, expression: exp.Ceil | exp.Floor) -> str: + to_clause = self.sql(expression, "to") + if to_clause: + return f"{expression.sql_name()}({self.sql(expression, 'this')} TO {to_clause})" + + return self.function_fallback_sql(expression) + + def function_fallback_sql(self, expression: exp.Func) -> str: + args = [] + + for key in expression.arg_types: + arg_value = expression.args.get(key) + + if isinstance(arg_value, list): + for value in arg_value: + args.append(value) + elif arg_value is not None: + args.append(arg_value) + + if self.dialect.PRESERVE_ORIGINAL_NAMES: + name = ( + expression._meta and expression.meta.get("name") + ) or expression.sql_name() + else: + name = expression.sql_name() + + return self.func(name, *args) + + def func( + self, + name: str, + *args: t.Optional[exp.Expression | str], + prefix: str = "(", + suffix: str = ")", + normalize: bool = True, + ) -> str: + name = self.normalize_func(name) if normalize else name + return f"{name}{prefix}{self.format_args(*args)}{suffix}" + + def format_args( + self, *args: t.Optional[str | exp.Expression], sep: str = ", " + ) -> str: + arg_sqls = tuple( + self.sql(arg) + for arg in args + if arg is not None and not isinstance(arg, bool) + ) + if self.pretty and self.too_wide(arg_sqls): + return self.indent( + "\n" + f"{sep.strip()}\n".join(arg_sqls) + "\n", + skip_first=True, + skip_last=True, + ) + return sep.join(arg_sqls) + + def too_wide(self, 
args: t.Iterable) -> bool: + return sum(len(arg) for arg in args) > self.max_text_width + + def format_time( + self, + expression: exp.Expression, + inverse_time_mapping: t.Optional[t.Dict[str, str]] = None, + inverse_time_trie: t.Optional[t.Dict] = None, + ) -> t.Optional[str]: + return format_time( + self.sql(expression, "format"), + inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING, + inverse_time_trie or self.dialect.INVERSE_TIME_TRIE, + ) + + def expressions( + self, + expression: t.Optional[exp.Expression] = None, + key: t.Optional[str] = None, + sqls: t.Optional[t.Collection[str | exp.Expression]] = None, + flat: bool = False, + indent: bool = True, + skip_first: bool = False, + skip_last: bool = False, + sep: str = ", ", + prefix: str = "", + dynamic: bool = False, + new_line: bool = False, + ) -> str: + expressions = expression.args.get(key or "expressions") if expression else sqls + + if not expressions: + return "" + + if flat: + return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql) + + num_sqls = len(expressions) + result_sqls = [] + + for i, e in enumerate(expressions): + sql = self.sql(e, comment=False) + if not sql: + continue + + comments = ( + self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" + ) + + if self.pretty: + if self.leading_comma: + result_sqls.append(f"{sep if i > 0 else ''}{prefix}{sql}{comments}") + else: + result_sqls.append( + f"{prefix}{sql}{(sep.rstrip() if comments else sep) if i + 1 < num_sqls else ''}{comments}" + ) + else: + result_sqls.append( + f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}" + ) + + if self.pretty and (not dynamic or self.too_wide(result_sqls)): + if new_line: + result_sqls.insert(0, "") + result_sqls.append("") + result_sql = "\n".join(s.rstrip() for s in result_sqls) + else: + result_sql = "".join(result_sqls) + + return ( + self.indent(result_sql, skip_first=skip_first, skip_last=skip_last) + if indent + else result_sql + ) + + def op_expressions( + self, op: str, expression: exp.Expression, flat: bool = False + ) -> str: + flat = flat or isinstance(expression.parent, exp.Properties) + expressions_sql = self.expressions(expression, flat=flat) + if flat: + return f"{op} {expressions_sql}" + return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" + + def naked_property(self, expression: exp.Property) -> str: + property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) + if not property_name: + self.unsupported(f"Unsupported property {expression.__class__.__name__}") + return f"{property_name} {self.sql(expression, 'this')}" + + def tag_sql(self, expression: exp.Tag) -> str: + return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}" + + def token_sql(self, token_type: TokenType) -> str: + return self.TOKEN_MAPPING.get(token_type, token_type.name) + + def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: + this = self.sql(expression, "this") + expressions = self.no_identify(self.expressions, expression) + expressions = ( + self.wrap(expressions) + if expression.args.get("wrapped") + else f" {expressions}" + ) + return f"{this}{expressions}" if expressions.strip() != "" else this + + def joinhint_sql(self, expression: exp.JoinHint) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"{this}({expressions})" + + def kwarg_sql(self, expression: exp.Kwarg) -> str: + return self.binary(expression, "=>") + + def 
when_sql(self, expression: exp.When) -> str: + matched = "MATCHED" if expression.args["matched"] else "NOT MATCHED" + source = ( + " BY SOURCE" + if self.MATCHED_BY_SOURCE and expression.args.get("source") + else "" + ) + condition = self.sql(expression, "condition") + condition = f" AND {condition}" if condition else "" + + then_expression = expression.args.get("then") + if isinstance(then_expression, exp.Insert): + this = self.sql(then_expression, "this") + this = f"INSERT {this}" if this else "INSERT" + then = self.sql(then_expression, "expression") + then = f"{this} VALUES {then}" if then else this + elif isinstance(then_expression, exp.Update): + if isinstance(then_expression.args.get("expressions"), exp.Star): + then = f"UPDATE {self.sql(then_expression, 'expressions')}" + else: + expressions_sql = self.expressions(then_expression) + then = ( + f"UPDATE SET{self.sep()}{expressions_sql}" + if expressions_sql + else "UPDATE" + ) + + else: + then = self.sql(then_expression) + return f"WHEN {matched}{source}{condition} THEN {then}" + + def whens_sql(self, expression: exp.Whens) -> str: + return self.expressions(expression, sep=" ", indent=False) + + def merge_sql(self, expression: exp.Merge) -> str: + table = expression.this + table_alias = "" + + hints = table.args.get("hints") + if hints and table.alias and isinstance(hints[0], exp.WithTableHint): + # T-SQL syntax is MERGE ... [WITH (<merge_hint>)] [[AS] table_alias] + table_alias = f" AS {self.sql(table.args['alias'].pop())}" + + this = self.sql(table) + using = f"USING {self.sql(expression, 'using')}" + whens = self.sql(expression, "whens") + + on = self.sql(expression, "on") + on = f"ON {on}" if on else "" + + if not on: + on = self.expressions(expression, key="using_cond") + on = f"USING ({on})" if on else "" + + returning = self.sql(expression, "returning") + if returning: + whens = f"{whens}{returning}" + + sep = self.sep() + + return self.prepend_ctes( + expression, + f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{whens}", + ) + + @unsupported_args("format") + def tochar_sql(self, expression: exp.ToChar) -> str: + return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT)) + + def tonumber_sql(self, expression: exp.ToNumber) -> str: + if not self.SUPPORTS_TO_NUMBER: + self.unsupported("Unsupported TO_NUMBER function") + return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) + + fmt = expression.args.get("format") + if not fmt: + self.unsupported("Conversion format is required for TO_NUMBER") + return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) + + return self.func("TO_NUMBER", expression.this, fmt) + + def dictproperty_sql(self, expression: exp.DictProperty) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + settings_sql = self.expressions(expression, key="settings", sep=" ") + args = ( + f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" + if settings_sql + else "()" + ) + return f"{this}({kind}{args})" + + def dictrange_sql(self, expression: exp.DictRange) -> str: + this = self.sql(expression, "this") + max = self.sql(expression, "max") + min = self.sql(expression, "min") + return f"{this}(MIN {min} MAX {max})" + + def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str: + return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}" + + def duplicatekeyproperty_sql(self, expression: exp.DuplicateKeyProperty) -> str: + return f"DUPLICATE KEY ({self.expressions(expression, flat=True)})" + + #
https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ + def uniquekeyproperty_sql( + self, expression: exp.UniqueKeyProperty, prefix: str = "UNIQUE KEY" + ) -> str: + return f"{prefix} ({self.expressions(expression, flat=True)})" + + # https://docs.starrocks.io/docs/sql-reference/sql-statements/data-definition/CREATE_TABLE/#distribution_desc + def distributedbyproperty_sql(self, expression: exp.DistributedByProperty) -> str: + expressions = self.expressions(expression, flat=True) + expressions = f" {self.wrap(expressions)}" if expressions else "" + buckets = self.sql(expression, "buckets") + kind = self.sql(expression, "kind") + buckets = f" BUCKETS {buckets}" if buckets else "" + order = self.sql(expression, "order") + return f"DISTRIBUTED BY {kind}{expressions}{buckets}{order}" + + def oncluster_sql(self, expression: exp.OnCluster) -> str: + return "" + + def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str: + expressions = self.expressions(expression, key="expressions", flat=True) + sorted_by = self.expressions(expression, key="sorted_by", flat=True) + sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else "" + buckets = self.sql(expression, "buckets") + return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS" + + def anyvalue_sql(self, expression: exp.AnyValue) -> str: + this = self.sql(expression, "this") + having = self.sql(expression, "having") + + if having: + this = f"{this} HAVING {'MAX' if expression.args.get('max') else 'MIN'} {having}" + + return self.func("ANY_VALUE", this) + + def querytransform_sql(self, expression: exp.QueryTransform) -> str: + transform = self.func("TRANSFORM", *expression.expressions) + row_format_before = self.sql(expression, "row_format_before") + row_format_before = f" {row_format_before}" if row_format_before else "" + record_writer = self.sql(expression, "record_writer") + record_writer = f" RECORDWRITER {record_writer}" if record_writer else "" + using = f" USING {self.sql(expression, 'command_script')}" + schema = self.sql(expression, "schema") + schema = f" AS {schema}" if schema else "" + row_format_after = self.sql(expression, "row_format_after") + row_format_after = f" {row_format_after}" if row_format_after else "" + record_reader = self.sql(expression, "record_reader") + record_reader = f" RECORDREADER {record_reader}" if record_reader else "" + return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}" + + def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str: + key_block_size = self.sql(expression, "key_block_size") + if key_block_size: + return f"KEY_BLOCK_SIZE = {key_block_size}" + + using = self.sql(expression, "using") + if using: + return f"USING {using}" + + parser = self.sql(expression, "parser") + if parser: + return f"WITH PARSER {parser}" + + comment = self.sql(expression, "comment") + if comment: + return f"COMMENT {comment}" + + visible = expression.args.get("visible") + if visible is not None: + return "VISIBLE" if visible else "INVISIBLE" + + engine_attr = self.sql(expression, "engine_attr") + if engine_attr: + return f"ENGINE_ATTRIBUTE = {engine_attr}" + + secondary_engine_attr = self.sql(expression, "secondary_engine_attr") + if secondary_engine_attr: + return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}" + + self.unsupported("Unsupported index constraint option.") + return "" + + def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) 
-> str: + enforced = " ENFORCED" if expression.args.get("enforced") else "" + return f"CHECK ({self.sql(expression, 'this')}){enforced}" + + def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str: + kind = self.sql(expression, "kind") + kind = f"{kind} INDEX" if kind else "INDEX" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + index_type = self.sql(expression, "index_type") + index_type = f" USING {index_type}" if index_type else "" + expressions = self.expressions(expression, flat=True) + expressions = f" ({expressions})" if expressions else "" + options = self.expressions(expression, key="options", sep=" ") + options = f" {options}" if options else "" + return f"{kind}{this}{index_type}{expressions}{options}" + + def nvl2_sql(self, expression: exp.Nvl2) -> str: + if self.NVL2_SUPPORTED: + return self.function_fallback_sql(expression) + + case = exp.Case().when( + expression.this.is_(exp.null()).not_(copy=False), + expression.args["true"], + copy=False, + ) + else_cond = expression.args.get("false") + if else_cond: + case.else_(else_cond, copy=False) + + return self.sql(case) + + def comprehension_sql(self, expression: exp.Comprehension) -> str: + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + position = self.sql(expression, "position") + position = f", {position}" if position else "" + iterator = self.sql(expression, "iterator") + condition = self.sql(expression, "condition") + condition = f" IF {condition}" if condition else "" + return f"{this} FOR {expr}{position} IN {iterator}{condition}" + + def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str: + return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})" + + def opclass_sql(self, expression: exp.Opclass) -> str: + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + + def _ml_sql(self, expression: exp.Func, name: str) -> str: + model = self.sql(expression, "this") + model = f"MODEL {model}" + expr = expression.expression + if expr: + expr_sql = self.sql(expression, "expression") + expr_sql = ( + f"TABLE {expr_sql}" if not isinstance(expr, exp.Subquery) else expr_sql + ) + else: + expr_sql = None + + parameters = self.sql(expression, "params_struct") or None + + return self.func(name, model, expr_sql, parameters) + + def predict_sql(self, expression: exp.Predict) -> str: + return self._ml_sql(expression, "PREDICT") + + def generateembedding_sql(self, expression: exp.GenerateEmbedding) -> str: + name = ( + "GENERATE_TEXT_EMBEDDING" + if expression.args.get("is_text") + else "GENERATE_EMBEDDING" + ) + return self._ml_sql(expression, name) + + def mltranslate_sql(self, expression: exp.MLTranslate) -> str: + return self._ml_sql(expression, "TRANSLATE") + + def mlforecast_sql(self, expression: exp.MLForecast) -> str: + return self._ml_sql(expression, "FORECAST") + + def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str: + this_sql = self.sql(expression, "this") + if isinstance(expression.this, exp.Table): + this_sql = f"TABLE {this_sql}" + + return self.func( + "FEATURES_AT_TIME", + this_sql, + expression.args.get("time"), + expression.args.get("num_rows"), + expression.args.get("ignore_feature_nulls"), + ) + + def vectorsearch_sql(self, expression: exp.VectorSearch) -> str: + this_sql = self.sql(expression, "this") + if isinstance(expression.this, exp.Table): + this_sql = f"TABLE {this_sql}" + + query_table = self.sql(expression, "query_table") + if 
isinstance(expression.args["query_table"], exp.Table): + query_table = f"TABLE {query_table}" + + return self.func( + "VECTOR_SEARCH", + this_sql, + expression.args.get("column_to_search"), + query_table, + expression.args.get("query_column_to_search"), + expression.args.get("top_k"), + expression.args.get("distance_type"), + expression.args.get("options"), + ) + + def forin_sql(self, expression: exp.ForIn) -> str: + this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + return f"FOR {this} DO {expression_sql}" + + def refresh_sql(self, expression: exp.Refresh) -> str: + this = self.sql(expression, "this") + kind = ( + "" + if isinstance(expression.this, exp.Literal) + else f"{expression.text('kind')} " + ) + return f"REFRESH {kind}{this}" + + def toarray_sql(self, expression: exp.ToArray) -> str: + arg = expression.this + if not arg.type: + from bigframes_vendored.sqlglot.optimizer.annotate_types import ( + annotate_types, + ) + + arg = annotate_types(arg, dialect=self.dialect) + + if arg.is_type(exp.DataType.Type.ARRAY): + return self.sql(arg) + + cond_for_null = arg.is_(exp.null()) + return self.sql( + exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False)) + ) + + def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str: + this = expression.this + time_format = self.format_time(expression) + + if time_format: + return self.sql( + exp.cast( + exp.StrToTime(this=this, format=expression.args["format"]), + exp.DataType.Type.TIME, + ) + ) + + if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME): + return self.sql(this) + + return self.sql(exp.cast(this, exp.DataType.Type.TIME)) + + def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str: + this = expression.this + if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type( + exp.DataType.Type.TIMESTAMP + ): + return self.sql(this) + + return self.sql( + exp.cast(this, exp.DataType.Type.TIMESTAMP, dialect=self.dialect) + ) + + def tsordstodatetime_sql(self, expression: exp.TsOrDsToDatetime) -> str: + this = expression.this + if isinstance(this, exp.TsOrDsToDatetime) or this.is_type( + exp.DataType.Type.DATETIME + ): + return self.sql(this) + + return self.sql( + exp.cast(this, exp.DataType.Type.DATETIME, dialect=self.dialect) + ) + + def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str: + this = expression.this + time_format = self.format_time(expression) + + if time_format and time_format not in ( + self.dialect.TIME_FORMAT, + self.dialect.DATE_FORMAT, + ): + return self.sql( + exp.cast( + exp.StrToTime(this=this, format=expression.args["format"]), + exp.DataType.Type.DATE, + ) + ) + + if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE): + return self.sql(this) + + return self.sql(exp.cast(this, exp.DataType.Type.DATE)) + + def unixdate_sql(self, expression: exp.UnixDate) -> str: + return self.sql( + exp.func( + "DATEDIFF", + expression.this, + exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE), + "day", + ) + ) + + def lastday_sql(self, expression: exp.LastDay) -> str: + if self.LAST_DAY_SUPPORTS_DATE_PART: + return self.function_fallback_sql(expression) + + unit = expression.text("unit") + if unit and unit != "MONTH": + self.unsupported("Date parts are not supported in LAST_DAY.") + + return self.func("LAST_DAY", expression.this) + + def dateadd_sql(self, expression: exp.DateAdd) -> str: + from bigframes_vendored.sqlglot.dialects.dialect import unit_to_str + + return self.func( + 
"DATE_ADD", expression.this, expression.expression, unit_to_str(expression) + ) + + def arrayany_sql(self, expression: exp.ArrayAny) -> str: + if self.CAN_IMPLEMENT_ARRAY_ANY: + filtered = exp.ArrayFilter( + this=expression.this, expression=expression.expression + ) + filtered_not_empty = exp.ArraySize(this=filtered).neq(0) + original_is_empty = exp.ArraySize(this=expression.this).eq(0) + return self.sql(exp.paren(original_is_empty.or_(filtered_not_empty))) + + from bigframes_vendored.sqlglot.dialects import Dialect + + # SQLGlot's executor supports ARRAY_ANY, so we don't wanna warn for the SQLGlot dialect + if self.dialect.__class__ != Dialect: + self.unsupported("ARRAY_ANY is unsupported") + + return self.function_fallback_sql(expression) + + def struct_sql(self, expression: exp.Struct) -> str: + expression.set( + "expressions", + [ + exp.alias_(e.expression, e.name if e.this.is_string else e.this) + if isinstance(e, exp.PropertyEQ) + else e + for e in expression.expressions + ], + ) + + return self.function_fallback_sql(expression) + + def partitionrange_sql(self, expression: exp.PartitionRange) -> str: + low = self.sql(expression, "this") + high = self.sql(expression, "expression") + + return f"{low} TO {high}" + + def truncatetable_sql(self, expression: exp.TruncateTable) -> str: + target = "DATABASE" if expression.args.get("is_database") else "TABLE" + tables = f" {self.expressions(expression)}" + + exists = " IF EXISTS" if expression.args.get("exists") else "" + + on_cluster = self.sql(expression, "cluster") + on_cluster = f" {on_cluster}" if on_cluster else "" + + identity = self.sql(expression, "identity") + identity = f" {identity} IDENTITY" if identity else "" + + option = self.sql(expression, "option") + option = f" {option}" if option else "" + + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + + return f"TRUNCATE {target}{exists}{tables}{on_cluster}{identity}{option}{partition}" + + # This transpiles T-SQL's CONVERT function + # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16 + def convert_sql(self, expression: exp.Convert) -> str: + to = expression.this + value = expression.expression + style = expression.args.get("style") + safe = expression.args.get("safe") + strict = expression.args.get("strict") + + if not to or not value: + return "" + + # Retrieve length of datatype and override to default if not specified + if ( + not seq_get(to.expressions, 0) + and to.this in self.PARAMETERIZABLE_TEXT_TYPES + ): + to = exp.DataType.build( + to.this, expressions=[exp.Literal.number(30)], nested=False + ) + + transformed: t.Optional[exp.Expression] = None + cast = exp.Cast if strict else exp.TryCast + + # Check whether a conversion with format (T-SQL calls this 'style') is applicable + if isinstance(style, exp.Literal) and style.is_int: + from bigframes_vendored.sqlglot.dialects.tsql import TSQL + + style_value = style.name + converted_style = TSQL.CONVERT_FORMAT_MAPPING.get(style_value) + if not converted_style: + self.unsupported(f"Unsupported T-SQL 'style' value: {style_value}") + + fmt = exp.Literal.string(converted_style) + + if to.this == exp.DataType.Type.DATE: + transformed = exp.StrToDate(this=value, format=fmt) + elif to.this in (exp.DataType.Type.DATETIME, exp.DataType.Type.DATETIME2): + transformed = exp.StrToTime(this=value, format=fmt) + elif to.this in self.PARAMETERIZABLE_TEXT_TYPES: + transformed = cast( + this=exp.TimeToStr(this=value, format=fmt), to=to, 
safe=safe + ) + elif to.this == exp.DataType.Type.TEXT: + transformed = exp.TimeToStr(this=value, format=fmt) + + if not transformed: + transformed = cast(this=value, to=to, safe=safe) + + return self.sql(transformed) + + def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: + this = expression.this + if isinstance(this, exp.JSONPathWildcard): + this = self.json_path_part(this) + return f".{this}" if this else "" + + if self.SAFE_JSON_PATH_KEY_RE.match(this): + return f".{this}" + + this = self.json_path_part(this) + return ( + f"[{this}]" + if self._quote_json_path_key_using_brackets + and self.JSON_PATH_BRACKETED_KEY_SUPPORTED + else f".{this}" + ) + + def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: + this = self.json_path_part(expression.this) + return f"[{this}]" if this else "" + + def _simplify_unless_literal(self, expression: E) -> E: + if not isinstance(expression, exp.Literal): + from bigframes_vendored.sqlglot.optimizer.simplify import simplify + + expression = simplify(expression, dialect=self.dialect) + + return expression + + def _embed_ignore_nulls( + self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str + ) -> str: + this = expression.this + if isinstance(this, self.RESPECT_IGNORE_NULLS_UNSUPPORTED_EXPRESSIONS): + self.unsupported( + f"RESPECT/IGNORE NULLS is not supported for {type(this).key} in {self.dialect.__class__.__name__}" + ) + return self.sql(this) + + if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"): + # The first modifier here will be the one closest to the AggFunc's arg + mods = sorted( + expression.find_all(exp.HavingMax, exp.Order, exp.Limit), + key=lambda x: 0 + if isinstance(x, exp.HavingMax) + else (1 if isinstance(x, exp.Order) else 2), + ) + + if mods: + mod = mods[0] + this = expression.__class__(this=mod.this.copy()) + this.meta["inline"] = True + mod.this.replace(this) + return self.sql(expression.this) + + agg_func = expression.find(exp.AggFunc) + + if agg_func: + agg_func_sql = self.sql(agg_func, comment=False)[:-1] + f" {text})" + return self.maybe_comment(agg_func_sql, comments=agg_func.comments) + + return f"{self.sql(expression, 'this')} {text}" + + def _replace_line_breaks(self, string: str) -> str: + """We don't want to extra-indent line breaks, so we temporarily replace them with sentinels.""" + if self.pretty: + return string.replace("\n", self.SENTINEL_LINE_BREAK) + return string + + def copyparameter_sql(self, expression: exp.CopyParameter) -> str: + option = self.sql(expression, "this") + + if expression.expressions: + upper = option.upper() + + # Snowflake FILE_FORMAT options are separated by whitespace + sep = " " if upper == "FILE_FORMAT" else ", " + + # Databricks copy/format options do not set their list of values with EQ + op = " " if upper in ("COPY_OPTIONS", "FORMAT_OPTIONS") else " = " + values = self.expressions(expression, flat=True, sep=sep) + return f"{option}{op}({values})" + + value = self.sql(expression, "expression") + + if not value: + return option + + op = " = " if self.COPY_PARAMS_EQ_REQUIRED else " " + + return f"{option}{op}{value}" + + def credentials_sql(self, expression: exp.Credentials) -> str: + cred_expr = expression.args.get("credentials") + if isinstance(cred_expr, exp.Literal): + # Redshift case: CREDENTIALS <string> + credentials = self.sql(expression, "credentials") + credentials = f"CREDENTIALS {credentials}" if credentials else "" + else: + # Snowflake case: CREDENTIALS = (...)
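+ # The key-value pairs are joined with single spaces inside the parentheses (flat=True, sep=" ").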
+ credentials = self.expressions( + expression, key="credentials", flat=True, sep=" " + ) + credentials = ( + f"CREDENTIALS = ({credentials})" if cred_expr is not None else "" + ) + + storage = self.sql(expression, "storage") + storage = f"STORAGE_INTEGRATION = {storage}" if storage else "" + + encryption = self.expressions(expression, key="encryption", flat=True, sep=" ") + encryption = f" ENCRYPTION = ({encryption})" if encryption else "" + + iam_role = self.sql(expression, "iam_role") + iam_role = f"IAM_ROLE {iam_role}" if iam_role else "" + + region = self.sql(expression, "region") + region = f" REGION {region}" if region else "" + + return f"{credentials}{storage}{encryption}{iam_role}{region}" + + def copy_sql(self, expression: exp.Copy) -> str: + this = self.sql(expression, "this") + this = f" INTO {this}" if self.COPY_HAS_INTO_KEYWORD else f" {this}" + + credentials = self.sql(expression, "credentials") + credentials = self.seg(credentials) if credentials else "" + files = self.expressions(expression, key="files", flat=True) + kind = ( + self.seg("FROM" if expression.args.get("kind") else "TO") if files else "" + ) + + sep = ", " if self.dialect.COPY_PARAMS_ARE_CSV else " " + params = self.expressions( + expression, + key="params", + sep=sep, + new_line=True, + skip_last=True, + skip_first=True, + indent=self.COPY_PARAMS_ARE_WRAPPED, + ) + + if params: + if self.COPY_PARAMS_ARE_WRAPPED: + params = f" WITH ({params})" + elif not self.pretty and (files or credentials): + params = f" {params}" + + return f"COPY{this}{kind} {files}{credentials}{params}" + + def semicolon_sql(self, expression: exp.Semicolon) -> str: + return "" + + def datadeletionproperty_sql(self, expression: exp.DataDeletionProperty) -> str: + on_sql = "ON" if expression.args.get("on") else "OFF" + filter_col: t.Optional[str] = self.sql(expression, "filter_column") + filter_col = f"FILTER_COLUMN={filter_col}" if filter_col else None + retention_period: t.Optional[str] = self.sql(expression, "retention_period") + retention_period = ( + f"RETENTION_PERIOD={retention_period}" if retention_period else None + ) + + if filter_col or retention_period: + on_sql = self.func("ON", filter_col, retention_period) + + return f"DATA_DELETION={on_sql}" + + def maskingpolicycolumnconstraint_sql( + self, expression: exp.MaskingPolicyColumnConstraint + ) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + expressions = f" USING ({expressions})" if expressions else "" + return f"MASKING POLICY {this}{expressions}" + + def gapfill_sql(self, expression: exp.GapFill) -> str: + this = self.sql(expression, "this") + this = f"TABLE {this}" + return self.func( + "GAP_FILL", this, *[v for k, v in expression.args.items() if k != "this"] + ) + + def scope_resolution(self, rhs: str, scope_name: str) -> str: + return self.func("SCOPE_RESOLUTION", scope_name or None, rhs) + + def scoperesolution_sql(self, expression: exp.ScopeResolution) -> str: + this = self.sql(expression, "this") + expr = expression.expression + + if isinstance(expr, exp.Func): + # T-SQL's CLR functions are case sensitive + expr = f"{self.sql(expr, 'this')}({self.format_args(*expr.expressions)})" + else: + expr = self.sql(expression, "expression") + + return self.scope_resolution(expr, this) + + def parsejson_sql(self, expression: exp.ParseJSON) -> str: + if self.PARSE_JSON_NAME is None: + return self.sql(expression.this) + + return self.func(self.PARSE_JSON_NAME, expression.this, expression.expression) + + def rand_sql(self, 
expression: exp.Rand) -> str: + lower = self.sql(expression, "lower") + upper = self.sql(expression, "upper") + + if lower and upper: + return ( + f"({upper} - {lower}) * {self.func('RAND', expression.this)} + {lower}" + ) + return self.func("RAND", expression.this) + + def changes_sql(self, expression: exp.Changes) -> str: + information = self.sql(expression, "information") + information = f"INFORMATION => {information}" + at_before = self.sql(expression, "at_before") + at_before = f"{self.seg('')}{at_before}" if at_before else "" + end = self.sql(expression, "end") + end = f"{self.seg('')}{end}" if end else "" + + return f"CHANGES ({information}){at_before}{end}" + + def pad_sql(self, expression: exp.Pad) -> str: + prefix = "L" if expression.args.get("is_left") else "R" + + fill_pattern = self.sql(expression, "fill_pattern") or None + if not fill_pattern and self.PAD_FILL_PATTERN_IS_REQUIRED: + fill_pattern = "' '" + + return self.func( + f"{prefix}PAD", expression.this, expression.expression, fill_pattern + ) + + def summarize_sql(self, expression: exp.Summarize) -> str: + table = " TABLE" if expression.args.get("table") else "" + return f"SUMMARIZE{table} {self.sql(expression.this)}" + + def explodinggenerateseries_sql( + self, expression: exp.ExplodingGenerateSeries + ) -> str: + generate_series = exp.GenerateSeries(**expression.args) + + parent = expression.parent + if isinstance(parent, (exp.Alias, exp.TableAlias)): + parent = parent.parent + + if self.SUPPORTS_EXPLODING_PROJECTIONS and not isinstance( + parent, (exp.Table, exp.Unnest) + ): + return self.sql(exp.Unnest(expressions=[generate_series])) + + if isinstance(parent, exp.Select): + self.unsupported("GenerateSeries projection unnesting is not supported.") + + return self.sql(generate_series) + + def arrayconcat_sql( + self, expression: exp.ArrayConcat, name: str = "ARRAY_CONCAT" + ) -> str: + exprs = expression.expressions + if not self.ARRAY_CONCAT_IS_VAR_LEN: + if len(exprs) == 0: + rhs: t.Union[str, exp.Expression] = exp.Array(expressions=[]) + else: + rhs = reduce( + lambda x, y: exp.ArrayConcat(this=x, expressions=[y]), exprs + ) + else: + rhs = self.expressions(expression) # type: ignore + + return self.func(name, expression.this, rhs or None) + + def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str: + if self.SUPPORTS_CONVERT_TIMEZONE: + return self.function_fallback_sql(expression) + + source_tz = expression.args.get("source_tz") + target_tz = expression.args.get("target_tz") + timestamp = expression.args.get("timestamp") + + if source_tz and timestamp: + timestamp = exp.AtTimeZone( + this=exp.cast(timestamp, exp.DataType.Type.TIMESTAMPNTZ), zone=source_tz + ) + + expr = exp.AtTimeZone(this=timestamp, zone=target_tz) + + return self.sql(expr) + + def json_sql(self, expression: exp.JSON) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + + _with = expression.args.get("with_") + + if _with is None: + with_sql = "" + elif not _with: + with_sql = " WITHOUT" + else: + with_sql = " WITH" + + unique_sql = " UNIQUE KEYS" if expression.args.get("unique") else "" + + return f"JSON{this}{with_sql}{unique_sql}" + + def jsonvalue_sql(self, expression: exp.JSONValue) -> str: + def _generate_on_options(arg: t.Any) -> str: + return arg if isinstance(arg, str) else f"DEFAULT {self.sql(arg)}" + + path = self.sql(expression, "path") + returning = self.sql(expression, "returning") + returning = f" RETURNING {returning}" if returning else "" + + on_condition = self.sql(expression, 
"on_condition") + on_condition = f" {on_condition}" if on_condition else "" + + return self.func( + "JSON_VALUE", expression.this, f"{path}{returning}{on_condition}" + ) + + def conditionalinsert_sql(self, expression: exp.ConditionalInsert) -> str: + else_ = "ELSE " if expression.args.get("else_") else "" + condition = self.sql(expression, "expression") + condition = f"WHEN {condition} THEN " if condition else else_ + insert = self.sql(expression, "this")[len("INSERT") :].strip() + return f"{condition}{insert}" + + def multitableinserts_sql(self, expression: exp.MultitableInserts) -> str: + kind = self.sql(expression, "kind") + expressions = self.seg(self.expressions(expression, sep=" ")) + res = f"INSERT {kind}{expressions}{self.seg(self.sql(expression, 'source'))}" + return res + + def oncondition_sql(self, expression: exp.OnCondition) -> str: + # Static options like "NULL ON ERROR" are stored as strings, in contrast to "DEFAULT ON ERROR" + empty = expression.args.get("empty") + empty = ( + f"DEFAULT {empty} ON EMPTY" + if isinstance(empty, exp.Expression) + else self.sql(expression, "empty") + ) + + error = expression.args.get("error") + error = ( + f"DEFAULT {error} ON ERROR" + if isinstance(error, exp.Expression) + else self.sql(expression, "error") + ) + + if error and empty: + error = ( + f"{empty} {error}" + if self.dialect.ON_CONDITION_EMPTY_BEFORE_ERROR + else f"{error} {empty}" + ) + empty = "" + + null = self.sql(expression, "null") + + return f"{empty}{error}{null}" + + def jsonextractquote_sql(self, expression: exp.JSONExtractQuote) -> str: + scalar = " ON SCALAR STRING" if expression.args.get("scalar") else "" + return f"{self.sql(expression, 'option')} QUOTES{scalar}" + + def jsonexists_sql(self, expression: exp.JSONExists) -> str: + this = self.sql(expression, "this") + path = self.sql(expression, "path") + + passing = self.expressions(expression, "passing") + passing = f" PASSING {passing}" if passing else "" + + on_condition = self.sql(expression, "on_condition") + on_condition = f" {on_condition}" if on_condition else "" + + path = f"{path}{passing}{on_condition}" + + return self.func("JSON_EXISTS", this, path) + + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: + array_agg = self.function_fallback_sql(expression) + + # Add a NULL FILTER on the column to mimic the results going from a dialect that excludes nulls + # on ARRAY_AGG (e.g Spark) to one that doesn't (e.g. DuckDB) + if self.dialect.ARRAY_AGG_INCLUDES_NULLS and expression.args.get( + "nulls_excluded" + ): + parent = expression.parent + if isinstance(parent, exp.Filter): + parent_cond = parent.expression.this + parent_cond.replace( + parent_cond.and_(expression.this.is_(exp.null()).not_()) + ) + else: + this = expression.this + # Do not add the filter if the input is not a column (e.g. 
literal, struct etc) + if this.find(exp.Column): + # DISTINCT is already present in the agg function, do not propagate it to FILTER as well + this_sql = ( + self.expressions(this) + if isinstance(this, exp.Distinct) + else self.sql(expression, "this") + ) + + array_agg = f"{array_agg} FILTER(WHERE {this_sql} IS NOT NULL)" + + return array_agg + + def slice_sql(self, expression: exp.Slice) -> str: + step = self.sql(expression, "step") + end = self.sql(expression.expression) + begin = self.sql(expression.this) + + sql = f"{end}:{step}" if step else end + return f"{begin}:{sql}" if sql else f"{begin}:" + + def apply_sql(self, expression: exp.Apply) -> str: + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + + return f"{this} APPLY({expr})" + + def _grant_or_revoke_sql( + self, + expression: exp.Grant | exp.Revoke, + keyword: str, + preposition: str, + grant_option_prefix: str = "", + grant_option_suffix: str = "", + ) -> str: + privileges_sql = self.expressions(expression, key="privileges", flat=True) + + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + + securable = self.sql(expression, "securable") + securable = f" {securable}" if securable else "" + + principals = self.expressions(expression, key="principals", flat=True) + + if not expression.args.get("grant_option"): + grant_option_prefix = grant_option_suffix = "" + + # cascade for revoke only + cascade = self.sql(expression, "cascade") + cascade = f" {cascade}" if cascade else "" + + return f"{keyword} {grant_option_prefix}{privileges_sql} ON{kind}{securable} {preposition} {principals}{grant_option_suffix}{cascade}" + + def grant_sql(self, expression: exp.Grant) -> str: + return self._grant_or_revoke_sql( + expression, + keyword="GRANT", + preposition="TO", + grant_option_suffix=" WITH GRANT OPTION", + ) + + def revoke_sql(self, expression: exp.Revoke) -> str: + return self._grant_or_revoke_sql( + expression, + keyword="REVOKE", + preposition="FROM", + grant_option_prefix="GRANT OPTION FOR ", + ) + + def grantprivilege_sql(self, expression: exp.GrantPrivilege): + this = self.sql(expression, "this") + columns = self.expressions(expression, flat=True) + columns = f"({columns})" if columns else "" + + return f"{this}{columns}" + + def grantprincipal_sql(self, expression: exp.GrantPrincipal): + this = self.sql(expression, "this") + + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + + return f"{kind}{this}" + + def columns_sql(self, expression: exp.Columns): + func = self.function_fallback_sql(expression) + if expression.args.get("unpack"): + func = f"*{func}" + + return func + + def overlay_sql(self, expression: exp.Overlay): + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + from_sql = self.sql(expression, "from_") + for_sql = self.sql(expression, "for_") + for_sql = f" FOR {for_sql}" if for_sql else "" + + return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})" + + @unsupported_args("format") + def todouble_sql(self, expression: exp.ToDouble) -> str: + return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) + + def string_sql(self, expression: exp.String) -> str: + this = expression.this + zone = expression.args.get("zone") + + if zone: + # This is a BigQuery specific argument for STRING(, ) + # BigQuery stores timestamps internally as UTC, so ConvertTimezone is used with UTC + # set for source_tz to transpile the time conversion before the STRING cast + this = exp.ConvertTimezone( + 
source_tz=exp.Literal.string("UTC"), target_tz=zone, timestamp=this + ) + + return self.sql(exp.cast(this, exp.DataType.Type.VARCHAR)) + + def median_sql(self, expression: exp.Median): + if not self.SUPPORTS_MEDIAN: + return self.sql( + exp.PercentileCont( + this=expression.this, expression=exp.Literal.number(0.5) + ) + ) + + return self.function_fallback_sql(expression) + + def overflowtruncatebehavior_sql( + self, expression: exp.OverflowTruncateBehavior + ) -> str: + filler = self.sql(expression, "this") + filler = f" {filler}" if filler else "" + with_count = ( + "WITH COUNT" if expression.args.get("with_count") else "WITHOUT COUNT" + ) + return f"TRUNCATE{filler} {with_count}" + + def unixseconds_sql(self, expression: exp.UnixSeconds) -> str: + if self.SUPPORTS_UNIX_SECONDS: + return self.function_fallback_sql(expression) + + start_ts = exp.cast( + exp.Literal.string("1970-01-01 00:00:00+00"), + to=exp.DataType.Type.TIMESTAMPTZ, + ) + + return self.sql( + exp.TimestampDiff( + this=expression.this, expression=start_ts, unit=exp.var("SECONDS") + ) + ) + + def arraysize_sql(self, expression: exp.ArraySize) -> str: + dim = expression.expression + + # For dialects that don't support the dimension arg, we can safely transpile it's default value (1st dimension) + if dim and self.ARRAY_SIZE_DIM_REQUIRED is None: + if not (dim.is_int and dim.name == "1"): + self.unsupported("Cannot transpile dimension argument for ARRAY_LENGTH") + dim = None + + # If dimension is required but not specified, default initialize it + if self.ARRAY_SIZE_DIM_REQUIRED and not dim: + dim = exp.Literal.number(1) + + return self.func(self.ARRAY_SIZE_NAME, expression.this, dim) + + def attach_sql(self, expression: exp.Attach) -> str: + this = self.sql(expression, "this") + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" + expressions = self.expressions(expression) + expressions = f" ({expressions})" if expressions else "" + + return f"ATTACH{exists_sql} {this}{expressions}" + + def detach_sql(self, expression: exp.Detach) -> str: + this = self.sql(expression, "this") + # the DATABASE keyword is required if IF EXISTS is set + # without it, DuckDB throws an error: Parser Error: syntax error at or near "exists" (Line Number: 1) + # ref: https://duckdb.org/docs/stable/sql/statements/attach.html#detach-syntax + exists_sql = " DATABASE IF EXISTS" if expression.args.get("exists") else "" + + return f"DETACH{exists_sql} {this}" + + def attachoption_sql(self, expression: exp.AttachOption) -> str: + this = self.sql(expression, "this") + value = self.sql(expression, "expression") + value = f" {value}" if value else "" + return f"{this}{value}" + + def watermarkcolumnconstraint_sql( + self, expression: exp.WatermarkColumnConstraint + ) -> str: + return f"WATERMARK FOR {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" + + def encodeproperty_sql(self, expression: exp.EncodeProperty) -> str: + encode = "KEY ENCODE" if expression.args.get("key") else "ENCODE" + encode = f"{encode} {self.sql(expression, 'this')}" + + properties = expression.args.get("properties") + if properties: + encode = f"{encode} {self.properties(properties)}" + + return encode + + def includeproperty_sql(self, expression: exp.IncludeProperty) -> str: + this = self.sql(expression, "this") + include = f"INCLUDE {this}" + + column_def = self.sql(expression, "column_def") + if column_def: + include = f"{include} {column_def}" + + alias = self.sql(expression, "alias") + if alias: + include = f"{include} AS {alias}" + + 
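[Reviewer sketch] The `unixseconds_sql` fallback above emits a timestamp difference in seconds from the Unix epoch when a dialect has no native `UNIX_SECONDS`. A minimal plain-Python sketch of the same arithmetic (the helper name is illustrative, not part of this patch):

```python
# Mirror of the UNIX_SECONDS fallback: seconds elapsed since the epoch,
# computed as a timestamp difference rather than a dedicated function.
from datetime import datetime, timezone

def unix_seconds(ts: datetime) -> int:  # illustrative name, not in the patch
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    return int((ts - epoch).total_seconds())

assert unix_seconds(datetime(1970, 1, 2, tzinfo=timezone.utc)) == 86_400
```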
return include + + def xmlelement_sql(self, expression: exp.XMLElement) -> str: + name = f"NAME {self.sql(expression, 'this')}" + return self.func("XMLELEMENT", name, *expression.expressions) + + def xmlkeyvalueoption_sql(self, expression: exp.XMLKeyValueOption) -> str: + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + expr = f"({expr})" if expr else "" + return f"{this}{expr}" + + def partitionbyrangeproperty_sql( + self, expression: exp.PartitionByRangeProperty + ) -> str: + partitions = self.expressions(expression, "partition_expressions") + create = self.expressions(expression, "create_expressions") + return f"PARTITION BY RANGE {self.wrap(partitions)} {self.wrap(create)}" + + def partitionbyrangepropertydynamic_sql( + self, expression: exp.PartitionByRangePropertyDynamic + ) -> str: + start = self.sql(expression, "start") + end = self.sql(expression, "end") + + every = expression.args["every"] + if isinstance(every, exp.Interval) and every.this.is_string: + every.this.replace(exp.Literal.number(every.name)) + + return f"START {self.wrap(start)} END {self.wrap(end)} EVERY {self.wrap(self.sql(every))}" + + def unpivotcolumns_sql(self, expression: exp.UnpivotColumns) -> str: + name = self.sql(expression, "this") + values = self.expressions(expression, flat=True) + + return f"NAME {name} VALUE {values}" + + def analyzesample_sql(self, expression: exp.AnalyzeSample) -> str: + kind = self.sql(expression, "kind") + sample = self.sql(expression, "sample") + return f"SAMPLE {sample} {kind}" + + def analyzestatistics_sql(self, expression: exp.AnalyzeStatistics) -> str: + kind = self.sql(expression, "kind") + option = self.sql(expression, "option") + option = f" {option}" if option else "" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + columns = self.expressions(expression) + columns = f" {columns}" if columns else "" + return f"{kind}{option} STATISTICS{this}{columns}" + + def analyzehistogram_sql(self, expression: exp.AnalyzeHistogram) -> str: + this = self.sql(expression, "this") + columns = self.expressions(expression) + inner_expression = self.sql(expression, "expression") + inner_expression = f" {inner_expression}" if inner_expression else "" + update_options = self.sql(expression, "update_options") + update_options = f" {update_options} UPDATE" if update_options else "" + return f"{this} HISTOGRAM ON {columns}{inner_expression}{update_options}" + + def analyzedelete_sql(self, expression: exp.AnalyzeDelete) -> str: + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + return f"DELETE{kind} STATISTICS" + + def analyzelistchainedrows_sql(self, expression: exp.AnalyzeListChainedRows) -> str: + inner_expression = self.sql(expression, "expression") + return f"LIST CHAINED ROWS{inner_expression}" + + def analyzevalidate_sql(self, expression: exp.AnalyzeValidate) -> str: + kind = self.sql(expression, "kind") + this = self.sql(expression, "this") + this = f" {this}" if this else "" + inner_expression = self.sql(expression, "expression") + return f"VALIDATE {kind}{this}{inner_expression}" + + def analyze_sql(self, expression: exp.Analyze) -> str: + options = self.expressions(expression, key="options", sep=" ") + options = f" {options}" if options else "" + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + mode = self.sql(expression, "mode") + mode = f" {mode}" if mode else "" + properties = self.sql(expression, 
"properties") + properties = f" {properties}" if properties else "" + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + inner_expression = self.sql(expression, "expression") + inner_expression = f" {inner_expression}" if inner_expression else "" + return f"ANALYZE{options}{kind}{this}{partition}{mode}{inner_expression}{properties}" + + def xmltable_sql(self, expression: exp.XMLTable) -> str: + this = self.sql(expression, "this") + namespaces = self.expressions(expression, key="namespaces") + namespaces = f"XMLNAMESPACES({namespaces}), " if namespaces else "" + passing = self.expressions(expression, key="passing") + passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" + columns = self.expressions(expression, key="columns") + columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" + by_ref = ( + f"{self.sep()}RETURNING SEQUENCE BY REF" + if expression.args.get("by_ref") + else "" + ) + return f"XMLTABLE({self.sep('')}{self.indent(namespaces + this + passing + by_ref + columns)}{self.seg(')', sep='')}" + + def xmlnamespace_sql(self, expression: exp.XMLNamespace) -> str: + this = self.sql(expression, "this") + return this if isinstance(expression.this, exp.Alias) else f"DEFAULT {this}" + + def export_sql(self, expression: exp.Export) -> str: + this = self.sql(expression, "this") + connection = self.sql(expression, "connection") + connection = f"WITH CONNECTION {connection} " if connection else "" + options = self.sql(expression, "options") + return f"EXPORT DATA {connection}{options} AS {this}" + + def declare_sql(self, expression: exp.Declare) -> str: + return f"DECLARE {self.expressions(expression, flat=True)}" + + def declareitem_sql(self, expression: exp.DeclareItem) -> str: + variable = self.sql(expression, "this") + default = self.sql(expression, "default") + default = f" = {default}" if default else "" + + kind = self.sql(expression, "kind") + if isinstance(expression.args.get("kind"), exp.Schema): + kind = f"TABLE {kind}" + + return f"{variable} AS {kind}{default}" + + def recursivewithsearch_sql(self, expression: exp.RecursiveWithSearch) -> str: + kind = self.sql(expression, "kind") + this = self.sql(expression, "this") + set = self.sql(expression, "expression") + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + + kind_sql = kind if kind == "CYCLE" else f"SEARCH {kind} FIRST BY" + + return f"{kind_sql} {this} SET {set}{using}" + + def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: + params = self.expressions(expression, key="params", flat=True) + return self.func(expression.name, *expression.expressions) + f"({params})" + + def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str: + return self.func(expression.name, *expression.expressions) + + def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str: + return self.anonymousaggfunc_sql(expression) + + def combinedparameterizedagg_sql( + self, expression: exp.CombinedParameterizedAgg + ) -> str: + return self.parameterizedagg_sql(expression) + + def show_sql(self, expression: exp.Show) -> str: + self.unsupported("Unsupported SHOW statement") + return "" + + def install_sql(self, expression: exp.Install) -> str: + self.unsupported("Unsupported INSTALL statement") + return "" + + def get_put_sql(self, expression: exp.Put | exp.Get) -> str: + # Snowflake GET/PUT statements: + # PUT + # GET + props = expression.args.get("properties") + props_sql = ( + self.properties(props, 
prefix=" ", sep=" ", wrapped=False) if props else "" + ) + this = self.sql(expression, "this") + target = self.sql(expression, "target") + + if isinstance(expression, exp.Put): + return f"PUT {this} {target}{props_sql}" + else: + return f"GET {target} {this}{props_sql}" + + def translatecharacters_sql(self, expression: exp.TranslateCharacters): + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + with_error = " WITH ERROR" if expression.args.get("with_error") else "" + return f"TRANSLATE({this} USING {expr}{with_error})" + + def decodecase_sql(self, expression: exp.DecodeCase) -> str: + if self.SUPPORTS_DECODE_CASE: + return self.func("DECODE", *expression.expressions) + + expression, *expressions = expression.expressions + + ifs = [] + for search, result in zip(expressions[::2], expressions[1::2]): + if isinstance(search, exp.Literal): + ifs.append(exp.If(this=expression.eq(search), true=result)) + elif isinstance(search, exp.Null): + ifs.append(exp.If(this=expression.is_(exp.Null()), true=result)) + else: + if isinstance(search, exp.Binary): + search = exp.paren(search) + + cond = exp.or_( + expression.eq(search), + exp.and_( + expression.is_(exp.Null()), search.is_(exp.Null()), copy=False + ), + copy=False, + ) + ifs.append(exp.If(this=cond, true=result)) + + case = exp.Case( + ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None + ) + return self.sql(case) + + def semanticview_sql(self, expression: exp.SemanticView) -> str: + this = self.sql(expression, "this") + this = self.seg(this, sep="") + dimensions = self.expressions( + expression, "dimensions", dynamic=True, skip_first=True, skip_last=True + ) + dimensions = self.seg(f"DIMENSIONS {dimensions}") if dimensions else "" + metrics = self.expressions( + expression, "metrics", dynamic=True, skip_first=True, skip_last=True + ) + metrics = self.seg(f"METRICS {metrics}") if metrics else "" + facts = self.expressions( + expression, "facts", dynamic=True, skip_first=True, skip_last=True + ) + facts = self.seg(f"FACTS {facts}") if facts else "" + where = self.sql(expression, "where") + where = self.seg(f"WHERE {where}") if where else "" + body = self.indent(this + metrics + dimensions + facts + where, skip_first=True) + return f"SEMANTIC_VIEW({body}{self.seg(')', sep='')}" + + def getextract_sql(self, expression: exp.GetExtract) -> str: + this = expression.this + expr = expression.expression + + if not this.type or not expression.type: + from bigframes_vendored.sqlglot.optimizer.annotate_types import ( + annotate_types, + ) + + this = annotate_types(this, dialect=self.dialect) + + if this.is_type(*(exp.DataType.Type.ARRAY, exp.DataType.Type.MAP)): + return self.sql(exp.Bracket(this=this, expressions=[expr])) + + return self.sql( + exp.JSONExtract(this=this, expression=self.dialect.to_json_path(expr)) + ) + + def datefromunixdate_sql(self, expression: exp.DateFromUnixDate) -> str: + return self.sql( + exp.DateAdd( + this=exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE), + expression=expression.this, + unit=exp.var("DAY"), + ) + ) + + def space_sql(self: Generator, expression: exp.Space) -> str: + return self.sql(exp.Repeat(this=exp.Literal.string(" "), times=expression.this)) + + def buildproperty_sql(self, expression: exp.BuildProperty) -> str: + return f"BUILD {self.sql(expression, 'this')}" + + def refreshtriggerproperty_sql(self, expression: exp.RefreshTriggerProperty) -> str: + method = self.sql(expression, "method") + kind = expression.args.get("kind") + if not kind: + 
return f"REFRESH {method}" + + every = self.sql(expression, "every") + unit = self.sql(expression, "unit") + every = f" EVERY {every} {unit}" if every else "" + starts = self.sql(expression, "starts") + starts = f" STARTS {starts}" if starts else "" + + return f"REFRESH {method} ON {kind}{every}{starts}" + + def modelattribute_sql(self, expression: exp.ModelAttribute) -> str: + self.unsupported("The model!attribute syntax is not supported") + return "" + + def directorystage_sql(self, expression: exp.DirectoryStage) -> str: + return self.func("DIRECTORY", expression.this) + + def uuid_sql(self, expression: exp.Uuid) -> str: + is_string = expression.args.get("is_string", False) + uuid_func_sql = self.func("UUID") + + if is_string and not self.dialect.UUID_IS_STRING_TYPE: + return self.sql( + exp.cast(uuid_func_sql, exp.DataType.Type.VARCHAR, dialect=self.dialect) + ) + + return uuid_func_sql + + def initcap_sql(self, expression: exp.Initcap) -> str: + delimiters = expression.expression + + if delimiters: + # do not generate delimiters arg if we are round-tripping from default delimiters + if ( + delimiters.is_string + and delimiters.this == self.dialect.INITCAP_DEFAULT_DELIMITER_CHARS + ): + delimiters = None + elif not self.dialect.INITCAP_SUPPORTS_CUSTOM_DELIMITERS: + self.unsupported("INITCAP does not support custom delimiters") + delimiters = None + + return self.func("INITCAP", expression.this, delimiters) + + def localtime_sql(self, expression: exp.Localtime) -> str: + this = expression.this + return self.func("LOCALTIME", this) if this else "LOCALTIME" + + def localtimestamp_sql(self, expression: exp.Localtime) -> str: + this = expression.this + return self.func("LOCALTIMESTAMP", this) if this else "LOCALTIMESTAMP" + + def weekstart_sql(self, expression: exp.WeekStart) -> str: + this = expression.this.name.upper() + if self.dialect.WEEK_OFFSET == -1 and this == "SUNDAY": + # BigQuery specific optimization since WEEK(SUNDAY) == WEEK + return "WEEK" + + return self.func("WEEK", expression.this) diff --git a/third_party/bigframes_vendored/sqlglot/helper.py b/third_party/bigframes_vendored/sqlglot/helper.py new file mode 100644 index 0000000000..b06b96947d --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/helper.py @@ -0,0 +1,535 @@ +from __future__ import annotations + +from collections.abc import Collection, Set +from copy import copy +import datetime +from difflib import get_close_matches +from enum import Enum +import inspect +from itertools import count +import logging +import re +import sys +import typing as t + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot import exp + from bigframes_vendored.sqlglot._typing import A, E, T + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from bigframes_vendored.sqlglot.expressions import Expression + + +CAMEL_CASE_PATTERN = re.compile("(? t.Any: + return classmethod(self.fget).__get__(None, owner)() # type: ignore + + +def suggest_closest_match_and_fail( + kind: str, + word: str, + possibilities: t.Iterable[str], +) -> None: + close_matches = get_close_matches(word, possibilities, n=1) + + similar = seq_get(close_matches, 0) or "" + if similar: + similar = f" Did you mean {similar}?" 
+ + raise ValueError(f"Unknown {kind} '{word}'.{similar}") + + +def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: + """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" + try: + return seq[index] + except IndexError: + return None + + +@t.overload +def ensure_list(value: t.Collection[T]) -> t.List[T]: + ... + + +@t.overload +def ensure_list(value: None) -> t.List: + ... + + +@t.overload +def ensure_list(value: T) -> t.List[T]: + ... + + +def ensure_list(value): + """ + Ensures that a value is a list, otherwise casts or wraps it into one. + + Args: + value: The value of interest. + + Returns: + The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. + """ + if value is None: + return [] + if isinstance(value, (list, tuple)): + return list(value) + + return [value] + + +@t.overload +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: + ... + + +@t.overload +def ensure_collection(value: T) -> t.Collection[T]: + ... + + +def ensure_collection(value): + """ + Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. + + Args: + value: The value of interest. + + Returns: + The value if it's a collection, or else the value wrapped in a list. + """ + if value is None: + return [] + return ( + value + if isinstance(value, Collection) and not isinstance(value, (str, bytes)) + else [value] + ) + + +def csv(*args: str, sep: str = ", ") -> str: + """ + Formats any number of string arguments as CSV. + + Args: + args: The string arguments to format. + sep: The argument separator. + + Returns: + The arguments formatted as a CSV string. + """ + return sep.join(arg for arg in args if arg) + + +def subclasses( + module_name: str, + classes: t.Type | t.Tuple[t.Type, ...], + exclude: t.Set[t.Type] = set(), +) -> t.List[t.Type]: + """ + Returns all subclasses for a collection of classes, possibly excluding some of them. + + Args: + module_name: The name of the module to search for subclasses in. + classes: Class(es) we want to find the subclasses of. + exclude: Classes we want to exclude from the returned list. + + Returns: + The target subclasses. + """ + return [ + obj + for _, obj in inspect.getmembers( + sys.modules[module_name], + lambda obj: inspect.isclass(obj) + and issubclass(obj, classes) + and obj not in exclude, + ) + ] + + +def apply_index_offset( + this: exp.Expression, + expressions: t.List[E], + offset: int, + dialect: DialectType = None, +) -> t.List[E]: + """ + Applies an offset to a given integer literal expression. + + Args: + this: The target of the index. + expressions: The expression the offset will be applied to, wrapped in a list. + offset: The offset that will be applied. + dialect: the dialect of interest. + + Returns: + The original expression with the offset applied to it, wrapped in a list. If the provided + `expressions` argument contains more than one expression, it's returned unaffected. 
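[Reviewer sketch] For reference, the edge-case behavior of the sequence helpers above, assuming the vendored module is importable; each assertion follows directly from the definitions in this file:

```python
from bigframes_vendored.sqlglot.helper import ensure_collection, ensure_list, seq_get

assert seq_get([1, 2, 3], 5) is None       # out-of-bounds lookup returns None
assert ensure_list(None) == []             # None becomes an empty list
assert ensure_list((1, 2)) == [1, 2]       # tuples are cast to lists
assert ensure_list("ab") == ["ab"]         # strings are wrapped, not iterated
assert ensure_collection("ab") == ["ab"]   # str/bytes do not count as collections
```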
+ """ + if not offset or len(expressions) != 1: + return expressions + + expression = expressions[0] + + from bigframes_vendored.sqlglot import exp + from bigframes_vendored.sqlglot.optimizer.annotate_types import annotate_types + from bigframes_vendored.sqlglot.optimizer.simplify import simplify + + if not this.type: + annotate_types(this, dialect=dialect) + + if t.cast(exp.DataType, this.type).this not in ( + exp.DataType.Type.UNKNOWN, + exp.DataType.Type.ARRAY, + ): + return expressions + + if not expression.type: + annotate_types(expression, dialect=dialect) + + if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: + logger.info("Applying array index offset (%s)", offset) + expression = simplify(expression + offset) + return [expression] + + return expressions + + +def camel_to_snake_case(name: str) -> str: + """Converts `name` from camelCase to snake_case and returns the result.""" + return CAMEL_CASE_PATTERN.sub("_", name).upper() + + +def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: + """ + Applies a transformation to a given expression until a fix point is reached. + + Args: + expression: The expression to be transformed. + func: The transformation to be applied. + + Returns: + The transformed expression. + """ + + while True: + start_hash = hash(expression) + expression = func(expression) + end_hash = hash(expression) + + if start_hash == end_hash: + break + + return expression + + +def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: + """ + Sorts a given directed acyclic graph in topological order. + + Args: + dag: The graph to be sorted. + + Returns: + A list that contains all of the graph's nodes in topological order. + """ + result = [] + + for node, deps in tuple(dag.items()): + for dep in deps: + if dep not in dag: + dag[dep] = set() + + while dag: + current = {node for node, deps in dag.items() if not deps} + + if not current: + raise ValueError("Cycle error") + + for node in current: + dag.pop(node) + + for deps in dag.values(): + deps -= current + + result.extend(sorted(current)) # type: ignore + + return result + + +def find_new_name(taken: t.Collection[str], base: str) -> str: + """ + Searches for a new name. + + Args: + taken: A collection of taken names. + base: Base name to alter. + + Returns: + The new, available name. + """ + if base not in taken: + return base + + i = 2 + new = f"{base}_{i}" + while new in taken: + i += 1 + new = f"{base}_{i}" + + return new + + +def is_int(text: str) -> bool: + return is_type(text, int) + + +def is_float(text: str) -> bool: + return is_type(text, float) + + +def is_type(text: str, target_type: t.Type) -> bool: + try: + target_type(text) + return True + except ValueError: + return False + + +def name_sequence(prefix: str) -> t.Callable[[], str]: + """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" + sequence = count() + return lambda: f"{prefix}{next(sequence)}" + + +def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: + """Returns a dictionary created from an object's attributes.""" + return { + **{ + k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items() + }, + **kwargs, + } + + +def split_num_words( + value: str, sep: str, min_num_words: int, fill_from_start: bool = True +) -> t.List[t.Optional[str]]: + """ + Perform a split on a value and return N words as a result with `None` used for words that don't exist. + + Args: + value: The value to be split. + sep: The value to use to split on. 
+ min_num_words: The minimum number of words that are going to be in the result. + fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. + + Examples: + >>> split_num_words("db.table", ".", 3) + [None, 'db', 'table'] + >>> split_num_words("db.table", ".", 3, fill_from_start=False) + ['db', 'table', None] + >>> split_num_words("db.table", ".", 1) + ['db', 'table'] + + Returns: + The list of words returned by `split`, possibly augmented by a number of `None` values. + """ + words = value.split(sep) + if fill_from_start: + return [None] * (min_num_words - len(words)) + words + return words + [None] * (min_num_words - len(words)) + + +def is_iterable(value: t.Any) -> bool: + """ + Checks if the value is an iterable, excluding the types `str` and `bytes`. + + Examples: + >>> is_iterable([1,2]) + True + >>> is_iterable("test") + False + + Args: + value: The value to check if it is an iterable. + + Returns: + A `bool` value indicating if it is an iterable. + """ + from bigframes_vendored.sqlglot import Expression + + return hasattr(value, "__iter__") and not isinstance( + value, (str, bytes, Expression) + ) + + +def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: + """ + Flattens an iterable that can contain both iterable and non-iterable elements. Objects of + type `str` and `bytes` are not regarded as iterables. + + Examples: + >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) + [1, 2, 3, 4, 5, 'bla'] + >>> list(flatten([1, 2, 3])) + [1, 2, 3] + + Args: + values: The value to be flattened. + + Yields: + Non-iterable elements in `values`. + """ + for value in values: + if is_iterable(value): + yield from flatten(value) + else: + yield value + + +def dict_depth(d: t.Dict) -> int: + """ + Get the nesting depth of a dictionary. + + Example: + >>> dict_depth(None) + 0 + >>> dict_depth({}) + 1 + >>> dict_depth({"a": "b"}) + 1 + >>> dict_depth({"a": {}}) + 2 + >>> dict_depth({"a": {"b": {}}}) + 3 + """ + try: + return 1 + dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 + + +def first(it: t.Iterable[T]) -> T: + """Returns the first element from an iterable (useful for sets).""" + return next(i for i in it) + + +def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]: + if isinstance(value, bool) or value is None: + return value + + # Coerce the value to boolean if it matches to the truthy/falsy values below + value_lower = value.lower() + if value_lower in ("true", "1"): + return True + if value_lower in ("false", "0"): + return False + + return value + + +def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: + """ + Merges a sequence of ranges, represented as tuples (low, high) whose values + belong to some totally-ordered set. 
+ + Example: + >>> merge_ranges([(1, 3), (2, 6)]) + [(1, 6)] + """ + if not ranges: + return [] + + ranges = sorted(ranges) + + merged = [ranges[0]] + + for start, end in ranges[1:]: + last_start, last_end = merged[-1] + + if start <= last_end: + merged[-1] = (last_start, max(last_end, end)) + else: + merged.append((start, end)) + + return merged + + +def is_iso_date(text: str) -> bool: + try: + datetime.date.fromisoformat(text) + return True + except ValueError: + return False + + +def is_iso_datetime(text: str) -> bool: + try: + datetime.datetime.fromisoformat(text) + return True + except ValueError: + return False + + +# Interval units that operate on date components +DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} + + +def is_date_unit(expression: t.Optional[exp.Expression]) -> bool: + return expression is not None and expression.name.lower() in DATE_UNITS + + +K = t.TypeVar("K") +V = t.TypeVar("V") + + +class SingleValuedMapping(t.Mapping[K, V]): + """ + Mapping where all keys return the same value. + + This rigamarole is meant to avoid copying keys, which was originally intended + as an optimization while qualifying columns for tables with lots of columns. + """ + + def __init__(self, keys: t.Collection[K], value: V): + self._keys = keys if isinstance(keys, Set) else set(keys) + self._value = value + + def __getitem__(self, key: K) -> V: + if key in self._keys: + return self._value + raise KeyError(key) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> t.Iterator[K]: + return iter(self._keys) diff --git a/third_party/bigframes_vendored/sqlglot/jsonpath.py b/third_party/bigframes_vendored/sqlglot/jsonpath.py new file mode 100644 index 0000000000..916a68e17f --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/jsonpath.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot.errors import ParseError +import bigframes_vendored.sqlglot.expressions as exp +from bigframes_vendored.sqlglot.tokens import Token, Tokenizer, TokenType + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import Lit + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + +class JSONPathTokenizer(Tokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "[": TokenType.L_BRACKET, + "]": TokenType.R_BRACKET, + ":": TokenType.COLON, + ",": TokenType.COMMA, + "-": TokenType.DASH, + ".": TokenType.DOT, + "?": TokenType.PLACEHOLDER, + "@": TokenType.PARAMETER, + "'": TokenType.QUOTE, + '"': TokenType.QUOTE, + "$": TokenType.DOLLAR, + "*": TokenType.STAR, + } + + KEYWORDS = { + "..": TokenType.DOT, + } + + IDENTIFIER_ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] + + VAR_TOKENS = { + TokenType.VAR, + } + + +def parse(path: str, dialect: DialectType = None) -> exp.JSONPath: + """Takes in a JSON path string and parses it into a JSONPath expression.""" + from bigframes_vendored.sqlglot.dialects import Dialect + + jsonpath_tokenizer = Dialect.get_or_raise(dialect).jsonpath_tokenizer() + tokens = jsonpath_tokenizer.tokenize(path) + size = len(tokens) + + i = 0 + + def _curr() -> t.Optional[TokenType]: + return tokens[i].token_type if i < size else None + + def _prev() -> Token: + return tokens[i - 1] + + def _advance() -> Token: + nonlocal i + i += 1 + return _prev() + + def _error(msg: str) -> str: + return f"{msg} at index {i}: {path}" + + @t.overload + def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token: + pass + + @t.overload + 
def _match( + token_type: TokenType, raise_unmatched: Lit[False] = False + ) -> t.Optional[Token]: + pass + + def _match(token_type, raise_unmatched=False): + if _curr() == token_type: + return _advance() + if raise_unmatched: + raise ParseError(_error(f"Expected {token_type}")) + return None + + def _match_set(types: t.Collection[TokenType]) -> t.Optional[Token]: + return _advance() if _curr() in types else None + + def _parse_literal() -> t.Any: + token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER) + if token: + return token.text + if _match(TokenType.STAR): + return exp.JSONPathWildcard() + if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN): + script = _prev().text == "(" + start = i + + while True: + if _match(TokenType.L_BRACKET): + _parse_bracket() # nested call which we can throw away + if _curr() in (TokenType.R_BRACKET, None): + break + _advance() + + expr_type = exp.JSONPathScript if script else exp.JSONPathFilter + return expr_type(this=path[tokens[start].start : tokens[i].end]) + + number = "-" if _match(TokenType.DASH) else "" + + token = _match(TokenType.NUMBER) + if token: + number += token.text + + if number: + return int(number) + + return False + + def _parse_slice() -> t.Any: + start = _parse_literal() + end = _parse_literal() if _match(TokenType.COLON) else None + step = _parse_literal() if _match(TokenType.COLON) else None + + if end is None and step is None: + return start + + return exp.JSONPathSlice(start=start, end=end, step=step) + + def _parse_bracket() -> exp.JSONPathPart: + literal = _parse_slice() + + if isinstance(literal, str) or literal is not False: + indexes = [literal] + while _match(TokenType.COMMA): + literal = _parse_slice() + + if literal: + indexes.append(literal) + + if len(indexes) == 1: + if isinstance(literal, str): + node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0]) + elif isinstance(literal, exp.JSONPathPart) and isinstance( + literal, (exp.JSONPathScript, exp.JSONPathFilter) + ): + node = exp.JSONPathSelector(this=indexes[0]) + else: + node = exp.JSONPathSubscript(this=indexes[0]) + else: + node = exp.JSONPathUnion(expressions=indexes) + else: + raise ParseError(_error("Cannot have empty segment")) + + _match(TokenType.R_BRACKET, raise_unmatched=True) + + return node + + def _parse_var_text() -> str: + """ + Consumes & returns the text for a var. In BigQuery it's valid to have a key with spaces + in it, e.g JSON_QUERY(..., '$. a b c ') should produce a single JSONPathKey(' a b c '). + This is done by merging "consecutive" vars until a key separator is found (dot, colon etc) + or the path string is exhausted. + """ + prev_index = i - 2 + + while _match_set(jsonpath_tokenizer.VAR_TOKENS): + pass + + start = 0 if prev_index < 0 else tokens[prev_index].end + 1 + + if i >= len(tokens): + # This key is the last token for the path, so it's text is the remaining path + text = path[start:] + else: + text = path[start : tokens[i].start] + + return text + + # We canonicalize the JSON path AST so that it always starts with a + # "root" element, so paths like "field" will be generated as "$.field" + _match(TokenType.DOLLAR) + expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] + + while _curr(): + if _match(TokenType.DOT) or _match(TokenType.COLON): + recursive = _prev().text == ".." 
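[Reviewer sketch] As the comment near the end of `parse` notes, every path is canonicalized to begin with a root element. A brief usage sketch (hedged: assumes the vendored package is on the path):

```python
from bigframes_vendored.sqlglot import expressions as exp
from bigframes_vendored.sqlglot import jsonpath

path = jsonpath.parse("a[0].b")  # note: no leading "$" in the input
assert isinstance(path.expressions[0], exp.JSONPathRoot)       # "$" was prepended
assert isinstance(path.expressions[1], exp.JSONPathKey)        # key "a"
assert isinstance(path.expressions[2], exp.JSONPathSubscript)  # subscript [0]
```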
+ + if _match_set(jsonpath_tokenizer.VAR_TOKENS): + value: t.Optional[str | exp.JSONPathWildcard] = _parse_var_text() + elif _match(TokenType.IDENTIFIER): + value = _prev().text + elif _match(TokenType.STAR): + value = exp.JSONPathWildcard() + else: + value = None + + if recursive: + expressions.append(exp.JSONPathRecursive(this=value)) + elif value: + expressions.append(exp.JSONPathKey(this=value)) + else: + raise ParseError(_error("Expected key name or * after DOT")) + elif _match(TokenType.L_BRACKET): + expressions.append(_parse_bracket()) + elif _match_set(jsonpath_tokenizer.VAR_TOKENS): + expressions.append(exp.JSONPathKey(this=_parse_var_text())) + elif _match(TokenType.IDENTIFIER): + expressions.append(exp.JSONPathKey(this=_prev().text)) + elif _match(TokenType.STAR): + expressions.append(exp.JSONPathWildcard()) + else: + raise ParseError(_error(f"Unexpected {tokens[i].token_type}")) + + return exp.JSONPath(expressions=expressions) + + +JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { + exp.JSONPathFilter: lambda _, e: f"?{e.this}", + exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e), + exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}", + exp.JSONPathRoot: lambda *_: "$", + exp.JSONPathScript: lambda _, e: f"({e.this}", + exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]", + exp.JSONPathSlice: lambda self, e: ":".join( + "" if p is False else self.json_path_part(p) + for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")] + if p is not None + ), + exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e), + exp.JSONPathUnion: lambda self, e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]", + exp.JSONPathWildcard: lambda *_: "*", +} + +ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS) diff --git a/third_party/bigframes_vendored/sqlglot/lineage.py b/third_party/bigframes_vendored/sqlglot/lineage.py new file mode 100644 index 0000000000..a80d65c649 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/lineage.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import json +import logging +import typing as t + +from bigframes_vendored.sqlglot import exp, maybe_parse, Schema +from bigframes_vendored.sqlglot.errors import SqlglotError +from bigframes_vendored.sqlglot.optimizer import ( + build_scope, + find_all_in_scope, + normalize_identifiers, + qualify, + Scope, +) +from bigframes_vendored.sqlglot.optimizer.scope import ScopeType + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + +logger = logging.getLogger("sqlglot") + + +@dataclass(frozen=True) +class Node: + name: str + expression: exp.Expression + source: exp.Expression + downstream: t.List[Node] = field(default_factory=list) + source_name: str = "" + reference_node_name: str = "" + + def walk(self) -> t.Iterator[Node]: + yield self + + for d in self.downstream: + yield from d.walk() + + def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: + nodes = {} + edges = [] + + for node in self.walk(): + if isinstance(node.expression, exp.Table): + label = f"FROM {node.expression.this}" + title = f"
<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+                group = 1
+            else:
+                label = node.expression.sql(pretty=True, dialect=dialect)
+                source = node.source.transform(
+                    lambda n: (
+                        exp.Tag(this=n, prefix="<b>", postfix="</b>")
+                        if n is node.expression
+                        else n
+                    ),
+                    copy=False,
+                ).sql(pretty=True, dialect=dialect)
+                title = f"<pre>{source}</pre>
" + group = 0 + + node_id = id(node) + + nodes[node_id] = { + "id": node_id, + "label": label, + "title": title, + "group": group, + } + + for d in node.downstream: + edges.append({"from": node_id, "to": id(d)}) + return GraphHTML(nodes, edges, **opts) + + +def lineage( + column: str | exp.Column, + sql: str | exp.Expression, + schema: t.Optional[t.Dict | Schema] = None, + sources: t.Optional[t.Mapping[str, str | exp.Query]] = None, + dialect: DialectType = None, + scope: t.Optional[Scope] = None, + trim_selects: bool = True, + copy: bool = True, + **kwargs, +) -> Node: + """Build the lineage graph for a column of a SQL query. + + Args: + column: The column to build the lineage for. + sql: The SQL string or expression. + schema: The schema of tables. + sources: A mapping of queries which will be used to continue building lineage. + dialect: The dialect of input SQL. + scope: A pre-created scope to use instead. + trim_selects: Whether to clean up selects by trimming to only relevant columns. + copy: Whether to copy the Expression arguments. + **kwargs: Qualification optimizer kwargs. + + Returns: + A lineage node. + """ + + expression = maybe_parse(sql, copy=copy, dialect=dialect) + column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name + + if sources: + expression = exp.expand( + expression, + { + k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect)) + for k, v in sources.items() + }, + dialect=dialect, + copy=copy, + ) + + if not scope: + expression = qualify.qualify( + expression, + dialect=dialect, + schema=schema, + **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore + ) + + scope = build_scope(expression) + + if not scope: + raise SqlglotError("Cannot build lineage, sql must be SELECT") + + if not any(select.alias_or_name == column for select in scope.expression.selects): + raise SqlglotError(f"Cannot find column '{column}' in query.") + + return to_node(column, scope, dialect, trim_selects=trim_selects) + + +def to_node( + column: str | int, + scope: Scope, + dialect: DialectType, + scope_name: t.Optional[str] = None, + upstream: t.Optional[Node] = None, + source_name: t.Optional[str] = None, + reference_node_name: t.Optional[str] = None, + trim_selects: bool = True, +) -> Node: + # Find the specific select clause that is the source of the column we want. + # This can either be a specific, named select or a generic `*` clause. 
+ select = ( + scope.expression.selects[column] + if isinstance(column, int) + else next( + ( + select + for select in scope.expression.selects + if select.alias_or_name == column + ), + exp.Star() if scope.expression.is_star else scope.expression, + ) + ) + + if isinstance(scope.expression, exp.Subquery): + for source in scope.subquery_scopes: + return to_node( + column, + scope=source, + dialect=dialect, + upstream=upstream, + source_name=source_name, + reference_node_name=reference_node_name, + trim_selects=trim_selects, + ) + if isinstance(scope.expression, exp.SetOperation): + name = type(scope.expression).__name__.upper() + upstream = upstream or Node( + name=name, source=scope.expression, expression=select + ) + + index = ( + column + if isinstance(column, int) + else next( + ( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column or select.is_star + ), + -1, # mypy will not allow a None here, but a negative index should never be returned + ) + ) + + if index == -1: + raise ValueError(f"Could not find {column} in {scope.expression}") + + for s in scope.union_scopes: + to_node( + index, + scope=s, + dialect=dialect, + upstream=upstream, + source_name=source_name, + reference_node_name=reference_node_name, + trim_selects=trim_selects, + ) + + return upstream + + if trim_selects and isinstance(scope.expression, exp.Select): + # For better ergonomics in our node labels, replace the full select with + # a version that has only the column we care about. + # "x", SELECT x, y FROM foo + # => "x", SELECT x FROM foo + source = t.cast(exp.Expression, scope.expression.select(select, append=False)) + else: + source = scope.expression + + # Create the node for this step in the lineage chain, and attach it to the previous one. + node = Node( + name=f"{scope_name}.{column}" if scope_name else str(column), + source=source, + expression=select, + source_name=source_name or "", + reference_node_name=reference_node_name or "", + ) + + if upstream: + upstream.downstream.append(node) + + subquery_scopes = { + id(subquery_scope.expression): subquery_scope + for subquery_scope in scope.subquery_scopes + } + + for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): + subquery_scope = subquery_scopes.get(id(subquery)) + if not subquery_scope: + logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") + continue + + for name in subquery.named_selects: + to_node( + name, + scope=subquery_scope, + dialect=dialect, + upstream=node, + trim_selects=trim_selects, + ) + + # if the select is a star add all scope sources as downstreams + if isinstance(select, exp.Star): + for source in scope.sources.values(): + if isinstance(source, Scope): + source = source.expression + node.downstream.append( + Node(name=select.sql(comments=False), source=source, expression=source) + ) + + # Find all columns that went into creating this one to list their lineage nodes. 
+ source_columns = set(find_all_in_scope(select, exp.Column)) + + # If the source is a UDTF find columns used in the UDTF to generate the table + if isinstance(source, exp.UDTF): + source_columns |= set(source.find_all(exp.Column)) + derived_tables = [ + source.expression.parent + for source in scope.sources.values() + if isinstance(source, Scope) and source.is_derived_table + ] + else: + derived_tables = scope.derived_tables + + source_names = { + dt.alias: dt.comments[0].split()[1] + for dt in derived_tables + if dt.comments and dt.comments[0].startswith("source: ") + } + + pivots = scope.pivots + pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None + if pivot: + # For each aggregation function, the pivot creates a new column for each field in category + # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a, + # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum' + # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs + # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest + # in the lineage, so lookup the pivot column name by index and map that with the columns used + # in the aggregation. + # + # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b') + pivot_columns = pivot.args["columns"] + pivot_aggs_count = len(pivot.expressions) + + pivot_column_mapping = {} + for i, agg in enumerate(pivot.expressions): + agg_cols = list(agg.find_all(exp.Column)) + for col_index in range(i, len(pivot_columns), pivot_aggs_count): + pivot_column_mapping[pivot_columns[col_index].name] = agg_cols + + for c in source_columns: + table = c.table + source = scope.sources.get(table) + + if isinstance(source, Scope): + reference_node_name = None + if ( + source.scope_type == ScopeType.DERIVED_TABLE + and table not in source_names + ): + reference_node_name = table + elif source.scope_type == ScopeType.CTE: + selected_node, _ = scope.selected_sources.get(table, (None, None)) + reference_node_name = selected_node.name if selected_node else None + + # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. + to_node( + c.name, + scope=source, + dialect=dialect, + scope_name=table, + upstream=node, + source_name=source_names.get(table) or source_name, + reference_node_name=reference_node_name, + trim_selects=trim_selects, + ) + elif pivot and pivot.alias_or_name == c.table: + downstream_columns = [] + + column_name = c.name + if any(column_name == pivot_column.name for pivot_column in pivot_columns): + downstream_columns.extend(pivot_column_mapping[column_name]) + else: + # The column is not in the pivot, so it must be an implicit column of the + # pivoted source -- adapt column to be from the implicit pivoted source. 
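[Reviewer sketch] The stride-based mapping described in the pivot comment above can be illustrated standalone; this toy version uses plain strings where the real code maps pivot column names to the `exp.Column` nodes inside each aggregation:

```python
# With 2 aggregations and 2 categories, aggregation i owns pivot columns
# i, i + 2, i + 4, ... (the stride is the number of aggregations).
pivot_columns = ["cat_a_value_sum", "cat_a", "b_value_sum", "b"]
aggs = ["SUM(value) AS value_sum", "MAX(price)"]

mapping = {}
for i, agg in enumerate(aggs):
    for col_index in range(i, len(pivot_columns), len(aggs)):
        mapping[pivot_columns[col_index]] = agg

assert mapping["cat_a_value_sum"] == "SUM(value) AS value_sum"
assert mapping["cat_a"] == "MAX(price)"  # the unaliased agg owns the bare names
```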
+                downstream_columns.append(
+                    exp.column(c.this, table=pivot.parent.alias_or_name)
+                )
+
+            for downstream_column in downstream_columns:
+                table = downstream_column.table
+                source = scope.sources.get(table)
+                if isinstance(source, Scope):
+                    to_node(
+                        downstream_column.name,
+                        scope=source,
+                        scope_name=table,
+                        dialect=dialect,
+                        upstream=node,
+                        source_name=source_names.get(table) or source_name,
+                        reference_node_name=reference_node_name,
+                        trim_selects=trim_selects,
+                    )
+                else:
+                    source = source or exp.Placeholder()
+                    node.downstream.append(
+                        Node(
+                            name=downstream_column.sql(comments=False),
+                            source=source,
+                            expression=source,
+                        )
+                    )
+        else:
+            # The source is not a scope and the column is not in any pivot - we've reached the end
+            # of the line. At this point, if a source is not found it means this column's lineage
+            # is unknown. This can happen if the definition of a source used in a query is not
+            # passed into the `sources` map.
+            source = source or exp.Placeholder()
+            node.downstream.append(
+                Node(name=c.sql(comments=False), source=source, expression=source)
+            )
+
+    return node
+
+
+class GraphHTML:
+    """Node to HTML generator using vis.js.
+
+    https://visjs.github.io/vis-network/docs/network/
+    """
+
+    def __init__(
+        self,
+        nodes: t.Dict,
+        edges: t.List,
+        imports: bool = True,
+        options: t.Optional[t.Dict] = None,
+    ):
+        self.imports = imports
+
+        self.options = {
+            "height": "500px",
+            "width": "100%",
+            "layout": {
+                "hierarchical": {
+                    "enabled": True,
+                    "nodeSpacing": 200,
+                    "sortMethod": "directed",
+                },
+            },
+            "interaction": {
+                "dragNodes": False,
+                "selectable": False,
+            },
+            "physics": {
+                "enabled": False,
+            },
+            "edges": {
+                "arrows": "to",
+            },
+            "nodes": {
+                "font": "20px monaco",
+                "shape": "box",
+                "widthConstraint": {
+                    "maximum": 300,
+                },
+            },
+            **(options or {}),
+        }
+
+        self.nodes = nodes
+        self.edges = edges
+
+    def __str__(self):
+        nodes = json.dumps(list(self.nodes.values()))
+        edges = json.dumps(self.edges)
+        options = json.dumps(self.options)
+        imports = (
+            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
+  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
+  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
+            if self.imports
+            else ""
+        )
+
+        return f"""<div>
+  <div id="sqlglot-lineage"></div>
+  {imports}
+  <script type="text/javascript">
+    var nodes = new vis.DataSet({nodes})
+    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
+
+    var edges = new vis.DataSet({edges})
+
+    var network = new vis.Network(
+        document.getElementById("sqlglot-lineage"),
+        {{
+            nodes: nodes,
+            edges: edges,
+        }},
+        {options},
+    )
+  </script>
""" + + def _repr_html_(self) -> str: + return self.__str__() diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/__init__.py b/third_party/bigframes_vendored/sqlglot/optimizer/__init__.py new file mode 100644 index 0000000000..4f4a6e756e --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/__init__.py @@ -0,0 +1,22 @@ +# ruff: noqa: F401 + +from bigframes_vendored.sqlglot.optimizer.optimizer import ( # noqa: F401 + optimize as optimize, +) +from bigframes_vendored.sqlglot.optimizer.optimizer import RULES as RULES # noqa: F401 +from bigframes_vendored.sqlglot.optimizer.scope import ( # noqa: F401 + build_scope as build_scope, +) +from bigframes_vendored.sqlglot.optimizer.scope import ( # noqa: F401 + find_all_in_scope as find_all_in_scope, +) +from bigframes_vendored.sqlglot.optimizer.scope import ( # noqa: F401 + find_in_scope as find_in_scope, +) +from bigframes_vendored.sqlglot.optimizer.scope import Scope as Scope # noqa: F401 +from bigframes_vendored.sqlglot.optimizer.scope import ( # noqa: F401 + traverse_scope as traverse_scope, +) +from bigframes_vendored.sqlglot.optimizer.scope import ( # noqa: F401 + walk_in_scope as walk_in_scope, +) diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/annotate_types.py b/third_party/bigframes_vendored/sqlglot/optimizer/annotate_types.py new file mode 100644 index 0000000000..856e4065ee --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/annotate_types.py @@ -0,0 +1,893 @@ +from __future__ import annotations + +import functools +import logging +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect +from bigframes_vendored.sqlglot.helper import ( + ensure_list, + is_date_unit, + is_iso_date, + is_iso_datetime, + seq_get, +) +from bigframes_vendored.sqlglot.optimizer.scope import Scope, traverse_scope +from bigframes_vendored.sqlglot.schema import ensure_schema, MappingSchema, Schema + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import B, E + + BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] + BinaryCoercions = t.Dict[ + t.Tuple[exp.DataType.Type, exp.DataType.Type], + BinaryCoercionFunc, + ] + + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from bigframes_vendored.sqlglot.typing import ExpressionMetadataType + +logger = logging.getLogger("sqlglot") + + +def annotate_types( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + expression_metadata: t.Optional[ExpressionMetadataType] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + dialect: DialectType = None, + overwrite_types: bool = True, +) -> E: + """ + Infers the types of an expression, annotating its AST accordingly. + + Example: + >>> import sqlglot + >>> schema = {"y": {"cola": "SMALLINT"}} + >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" + >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) + >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" + + + Args: + expression: Expression to annotate. + schema: Database schema. + expression_metadata: Maps expression type to corresponding annotation function. + coerces_to: Maps expression type to set of types that it can be coerced into. + overwrite_types: Re-annotate the existing AST types. + + Returns: + The expression annotated with types. 
+ """ + + schema = ensure_schema(schema, dialect=dialect) + + return TypeAnnotator( + schema=schema, + expression_metadata=expression_metadata, + coerces_to=coerces_to, + overwrite_types=overwrite_types, + ).annotate(expression) + + +def _coerce_date_literal( + l: exp.Expression, unit: t.Optional[exp.Expression] +) -> exp.DataType.Type: + date_text = l.name + is_iso_date_ = is_iso_date(date_text) + + if is_iso_date_ and is_date_unit(unit): + return exp.DataType.Type.DATE + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date_ or is_iso_datetime(date_text): + return exp.DataType.Type.DATETIME + + return exp.DataType.Type.UNKNOWN + + +def _coerce_date( + l: exp.Expression, unit: t.Optional[exp.Expression] +) -> exp.DataType.Type: + if not is_date_unit(unit): + return exp.DataType.Type.DATETIME + return l.type.this if l.type else exp.DataType.Type.UNKNOWN + + +def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: + @functools.wraps(func) + def _swapped(ll: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + return func(r, ll) + + return _swapped + + +def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: + return { + **coercions, + **{(b, a): swap_args(func) for (a, b), func in coercions.items()}, + } + + +class _TypeAnnotator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): + # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + text_precedence = ( + exp.DataType.Type.TEXT, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NCHAR, + exp.DataType.Type.CHAR, + ) + numeric_precedence = ( + exp.DataType.Type.DECFLOAT, + exp.DataType.Type.DOUBLE, + exp.DataType.Type.FLOAT, + exp.DataType.Type.BIGDECIMAL, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.BIGINT, + exp.DataType.Type.INT, + exp.DataType.Type.SMALLINT, + exp.DataType.Type.TINYINT, + ) + timelike_precedence = ( + exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.DATETIME, + exp.DataType.Type.DATE, + ) + + for type_precedence in ( + text_precedence, + numeric_precedence, + timelike_precedence, + ): + coerces_to = set() + for data_type in type_precedence: + klass.COERCES_TO[data_type] = coerces_to.copy() + coerces_to |= {data_type} + return klass + + +class TypeAnnotator(metaclass=_TypeAnnotator): + NESTED_TYPES = { + exp.DataType.Type.ARRAY, + } + + # Specifies what types a given type can be coerced into (autofilled) + COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + + # Coercion functions for binary operations. + # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. 
+ BINARY_COERCIONS: BinaryCoercions = { + **swap_all( + { + (t, exp.DataType.Type.INTERVAL): lambda ll, r: _coerce_date_literal( + ll, r.args.get("unit") + ) + for t in exp.DataType.TEXT_TYPES + } + ), + **swap_all( + { + # text + numeric will yield the numeric type to match most dialects' semantics + (text, numeric): lambda ll, r: t.cast( + exp.DataType.Type, + ll.type if ll.type in exp.DataType.NUMERIC_TYPES else r.type, + ) + for text in exp.DataType.TEXT_TYPES + for numeric in exp.DataType.NUMERIC_TYPES + } + ), + **swap_all( + { + ( + exp.DataType.Type.DATE, + exp.DataType.Type.INTERVAL, + ): lambda ll, r: _coerce_date(ll, r.args.get("unit")), + } + ), + } + + def __init__( + self, + schema: Schema, + expression_metadata: t.Optional[ExpressionMetadataType] = None, + coerces_to: t.Optional[ + t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] + ] = None, + binary_coercions: t.Optional[BinaryCoercions] = None, + overwrite_types: bool = True, + ) -> None: + self.schema = schema + dialect = schema.dialect or Dialect() + self.dialect = dialect + self.expression_metadata = expression_metadata or dialect.EXPRESSION_METADATA + self.coerces_to = coerces_to or dialect.COERCES_TO or self.COERCES_TO + self.binary_coercions = binary_coercions or self.BINARY_COERCIONS + + # Caches the ids of annotated sub-Expressions, to ensure we only visit them once + self._visited: t.Set[int] = set() + + # Caches NULL-annotated expressions to set them to UNKNOWN after type inference is completed + self._null_expressions: t.Dict[int, exp.Expression] = {} + + # Databricks and Spark ≥v3 actually support NULL (i.e., VOID) as a type + self._supports_null_type = dialect.SUPPORTS_NULL_TYPE + + # Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the + # exp.SetOperation is the expression of a scope source, as selecting from it multiple times + # would reprocess the entire subtree to coerce the types of its operands' projections + self._setop_column_types: t.Dict[ + int, t.Dict[str, exp.DataType | exp.DataType.Type] + ] = {} + + # When set to False, this enables partial annotation by skipping already-annotated nodes + self._overwrite_types = overwrite_types + + def clear(self) -> None: + self._visited.clear() + self._null_expressions.clear() + self._setop_column_types.clear() + + def _set_type( + self, expression: E, target_type: t.Optional[exp.DataType | exp.DataType.Type] + ) -> E: + prev_type = expression.type + expression_id = id(expression) + + expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore + self._visited.add(expression_id) + + if ( + not self._supports_null_type + and t.cast(exp.DataType, expression.type).this == exp.DataType.Type.NULL + ): + self._null_expressions[expression_id] = expression + elif ( + prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL + ): + self._null_expressions.pop(expression_id, None) + + if ( + isinstance(expression, exp.Column) + and expression.is_type(exp.DataType.Type.JSON) + and (dot_parts := expression.meta.get("dot_parts")) + ): + # JSON dot access is case sensitive across all dialects, so we need to undo the normalization. 
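+            # For example, a lowercasing dialect would normalize col.Field1.Field2
+            # to col.field1.field2; "dot_parts" keeps the original path segments,
+            # which are re-applied here as quoted identifiers.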
+ i = iter(dot_parts) + parent = expression.parent + while isinstance(parent, exp.Dot): + parent.expression.set("this", exp.to_identifier(next(i), quoted=True)) + parent = parent.parent + + expression.meta.pop("dot_parts", None) + + return expression + + def annotate(self, expression: E, annotate_scope: bool = True) -> E: + # This flag is used to avoid costly scope traversals when we only care about annotating + # non-column expressions (partial type inference), e.g., when simplifying in the optimizer + if annotate_scope: + for scope in traverse_scope(expression): + self.annotate_scope(scope) + + # This takes care of non-traversable expressions + self._annotate_expression(expression) + + # Replace NULL type with the default type of the targeted dialect, since the former is not an actual type; + # it is mostly used to aid type coercion, e.g. in query set operations. + for expr in self._null_expressions.values(): + expr.type = self.dialect.DEFAULT_NULL_TYPE + + return expression + + def annotate_scope(self, scope: Scope) -> None: + selects = {} + + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + + expression = source.expression + if isinstance(expression, exp.UDTF): + values = [] + + if isinstance(expression, exp.Lateral): + if isinstance(expression.this, exp.Explode): + values = [expression.this.this] + elif isinstance(expression, exp.Unnest): + values = [expression] + elif not isinstance(expression, exp.TableFromRows): + values = expression.expressions[0].expressions + + if not values: + continue + + alias_column_names = expression.alias_column_names + + if ( + isinstance(expression, exp.Unnest) + and not alias_column_names + and expression.type + and expression.type.is_type(exp.DataType.Type.STRUCT) + ): + selects[name] = { + col_def.name: t.cast( + t.Union[exp.DataType, exp.DataType.Type], col_def.kind + ) + for col_def in expression.type.expressions + if isinstance(col_def, exp.ColumnDef) and col_def.kind + } + else: + selects[name] = { + alias: column.type + for alias, column in zip(alias_column_names, values) + } + elif isinstance(expression, exp.SetOperation) and len( + expression.left.selects + ) == len(expression.right.selects): + selects[name] = self._get_setop_column_types(expression) + + else: + selects[name] = {s.alias_or_name: s.type for s in expression.selects} + + if isinstance(self.schema, MappingSchema): + for table_column in scope.table_columns: + source = scope.sources.get(table_column.name) + + if isinstance(source, exp.Table): + schema = self.schema.find( + source, raise_on_missing=False, ensure_data_types=True + ) + if not isinstance(schema, dict): + continue + + struct_type = exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef(this=exp.to_identifier(c), kind=kind) + for c, kind in schema.items() + ], + nested=True, + ) + self._set_type(table_column, struct_type) + elif ( + isinstance(source, Scope) + and isinstance(source.expression, exp.Query) + and ( + source.expression.meta.get("query_type") + or exp.DataType.build("UNKNOWN") + ).is_type(exp.DataType.Type.STRUCT) + ): + self._set_type(table_column, source.expression.meta["query_type"]) + + # Iterate through all the expressions of the current scope in post-order, and annotate + self._annotate_expression(scope.expression, scope, selects) + + if self.dialect.QUERY_RESULTS_ARE_STRUCTS and isinstance( + scope.expression, exp.Query + ): + struct_type = exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef( + 
this=exp.to_identifier(select.output_name), + kind=select.type.copy() if select.type else None, + ) + for select in scope.expression.selects + ], + nested=True, + ) + + if not any( + cd.kind.is_type(exp.DataType.Type.UNKNOWN) + for cd in struct_type.expressions + if cd.kind + ): + # We don't use `_set_type` on purpose here. If we annotated the query directly, then + # using it in other contexts (e.g., ARRAY()) could result in incorrect type + # annotations, i.e., it shouldn't be interpreted as a STRUCT value. + scope.expression.meta["query_type"] = struct_type + + def _annotate_expression( + self, + expression: exp.Expression, + scope: t.Optional[Scope] = None, + selects: t.Optional[t.Dict[str, t.Dict[str, t.Any]]] = None, + ) -> None: + stack = [(expression, False)] + selects = selects or {} + + while stack: + expr, children_annotated = stack.pop() + + if id(expr) in self._visited or ( + not self._overwrite_types + and expr.type + and not expr.is_type(exp.DataType.Type.UNKNOWN) + ): + continue # We've already inferred the expression's type + + if not children_annotated: + stack.append((expr, True)) + for child_expr in expr.iter_expressions(): + stack.append((child_expr, False)) + continue + + if scope and isinstance(expr, exp.Column) and expr.table: + source = scope.sources.get(expr.table) + if isinstance(source, exp.Table): + self._set_type(expr, self.schema.get_column_type(source, expr)) + elif source: + if expr.table in selects and expr.name in selects[expr.table]: + self._set_type(expr, selects[expr.table][expr.name]) + elif isinstance(source.expression, exp.Unnest): + self._set_type(expr, source.expression.type) + else: + self._set_type(expr, exp.DataType.Type.UNKNOWN) + else: + self._set_type(expr, exp.DataType.Type.UNKNOWN) + + if expr.type and expr.type.args.get("nullable") is False: + expr.meta["nonnull"] = True + continue + + spec = self.expression_metadata.get(expr.__class__) + + if spec and (annotator := spec.get("annotator")): + annotator(self, expr) + elif spec and (returns := spec.get("returns")): + self._set_type(expr, t.cast(exp.DataType.Type, returns)) + else: + self._set_type(expr, exp.DataType.Type.UNKNOWN) + + def _maybe_coerce( + self, + type1: exp.DataType | exp.DataType.Type, + type2: exp.DataType | exp.DataType.Type, + ) -> exp.DataType | exp.DataType.Type: + """ + Returns type2 if type1 can be coerced into it, otherwise type1. + + If either type is parameterized (e.g. DECIMAL(18, 2) contains two parameters), + we assume type1 does not coerce into type2, so we also return it in this case. + """ + if isinstance(type1, exp.DataType): + if type1.expressions: + return type1 + type1_value = type1.this + else: + type1_value = type1 + + if isinstance(type2, exp.DataType): + if type2.expressions: + return type2 + type2_value = type2.this + else: + type2_value = type2 + + # We propagate the UNKNOWN type upwards if found + if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): + return exp.DataType.Type.UNKNOWN + + if type1_value == exp.DataType.Type.NULL: + return type2_value + if type2_value == exp.DataType.Type.NULL: + return type1_value + + return ( + type2_value + if type2_value in self.coerces_to.get(type1_value, {}) + else type1_value + ) + + def _get_setop_column_types( + self, setop: exp.SetOperation + ) -> t.Dict[str, exp.DataType | exp.DataType.Type]: + """ + Computes and returns the coerced column types for a SetOperation. + + This handles UNION, INTERSECT, EXCEPT, etc., coercing types across + left and right operands for all projections/columns. 
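+
+        For example, "SELECT CAST(a AS INT) FROM t UNION SELECT CAST(a AS BIGINT) FROM t"
+        types column "a" as BIGINT, since INT coerces to BIGINT under the default rules.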
+ + Args: + setop: The SetOperation expression to analyze + + Returns: + Dictionary mapping column names to their coerced types + """ + setop_id = id(setop) + if setop_id in self._setop_column_types: + return self._setop_column_types[setop_id] + + col_types: t.Dict[str, exp.DataType | exp.DataType.Type] = {} + + # Validate that left and right have same number of projections + if not ( + isinstance(setop, exp.SetOperation) + and setop.left.selects + and setop.right.selects + and len(setop.left.selects) == len(setop.right.selects) + ): + return col_types + + # Process a chain / sub-tree of set operations + for set_op in setop.walk( + prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery)) + ): + if not isinstance(set_op, exp.SetOperation): + continue + + if set_op.args.get("by_name"): + r_type_by_select = { + s.alias_or_name: s.type for s in set_op.right.selects + } + setop_cols = { + s.alias_or_name: self._maybe_coerce( + t.cast(exp.DataType, s.type), + r_type_by_select.get(s.alias_or_name) + or exp.DataType.Type.UNKNOWN, + ) + for s in set_op.left.selects + } + else: + setop_cols = { + ls.alias_or_name: self._maybe_coerce( + t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type) + ) + for ls, rs in zip(set_op.left.selects, set_op.right.selects) + } + + # Coerce intermediate results with the previously registered types, if they exist + for col_name, col_type in setop_cols.items(): + col_types[col_name] = self._maybe_coerce( + col_type, col_types.get(col_name, exp.DataType.Type.NULL) + ) + + self._setop_column_types[setop_id] = col_types + return col_types + + def _annotate_binary(self, expression: B) -> B: + left, right = expression.left, expression.right + if not left or not right: + expression_sql = expression.sql(self.dialect) + logger.warning( + f"Failed to annotate badly formed binary expression: {expression_sql}" + ) + self._set_type(expression, None) + return expression + + left_type, right_type = left.type.this, right.type.this # type: ignore + + if isinstance(expression, (exp.Connector, exp.Predicate)): + self._set_type(expression, exp.DataType.Type.BOOLEAN) + elif (left_type, right_type) in self.binary_coercions: + self._set_type( + expression, self.binary_coercions[(left_type, right_type)](left, right) + ) + else: + self._set_type(expression, self._maybe_coerce(left_type, right_type)) + + if isinstance(expression, exp.Is) or ( + left.meta.get("nonnull") is True and right.meta.get("nonnull") is True + ): + expression.meta["nonnull"] = True + + return expression + + def _annotate_unary(self, expression: E) -> E: + if isinstance(expression, exp.Not): + self._set_type(expression, exp.DataType.Type.BOOLEAN) + else: + self._set_type(expression, expression.this.type) + + if expression.this.meta.get("nonnull") is True: + expression.meta["nonnull"] = True + + return expression + + def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: + if expression.is_string: + self._set_type(expression, exp.DataType.Type.VARCHAR) + elif expression.is_int: + self._set_type(expression, exp.DataType.Type.INT) + else: + self._set_type(expression, exp.DataType.Type.DOUBLE) + + expression.meta["nonnull"] = True + + return expression + + @t.no_type_check + def _annotate_by_args( + self, + expression: E, + *args: str | exp.Expression, + promote: bool = False, + array: bool = False, + ) -> E: + literal_type = None + non_literal_type = None + nested_type = None + + for arg in args: + if isinstance(arg, str): + expressions = expression.args.get(arg) + else: + expressions = arg + + 
for expr in ensure_list(expressions): + expr_type = expr.type + + # Stop at the first nested data type found - we don't want to _maybe_coerce nested types + if expr_type.args.get("nested"): + nested_type = expr_type + break + + if not expr_type.is_type(exp.DataType.Type.UNKNOWN): + if isinstance(expr, exp.Literal): + literal_type = self._maybe_coerce( + literal_type or expr_type, expr_type + ) + else: + non_literal_type = self._maybe_coerce( + non_literal_type or expr_type, expr_type + ) + + if nested_type: + break + + result_type = None + + if nested_type: + result_type = nested_type + elif literal_type and non_literal_type: + if self.dialect.PRIORITIZE_NON_LITERAL_TYPES: + literal_this_type = ( + literal_type.this + if isinstance(literal_type, exp.DataType) + else literal_type + ) + non_literal_this_type = ( + non_literal_type.this + if isinstance(non_literal_type, exp.DataType) + else non_literal_type + ) + if ( + literal_this_type in exp.DataType.INTEGER_TYPES + and non_literal_this_type in exp.DataType.INTEGER_TYPES + ) or ( + literal_this_type in exp.DataType.REAL_TYPES + and non_literal_this_type in exp.DataType.REAL_TYPES + ): + result_type = non_literal_type + else: + result_type = literal_type or non_literal_type or exp.DataType.Type.UNKNOWN + + self._set_type( + expression, + result_type or self._maybe_coerce(non_literal_type, literal_type), + ) + + if promote: + if expression.type.this in exp.DataType.INTEGER_TYPES: + self._set_type(expression, exp.DataType.Type.BIGINT) + elif expression.type.this in exp.DataType.FLOAT_TYPES: + self._set_type(expression, exp.DataType.Type.DOUBLE) + + if array: + self._set_type( + expression, + exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[expression.type], + nested=True, + ), + ) + + return expression + + def _annotate_timeunit( + self, expression: exp.TimeUnit | exp.DateTrunc + ) -> exp.TimeUnit | exp.DateTrunc: + if expression.this.type.this in exp.DataType.TEXT_TYPES: + datatype = _coerce_date_literal(expression.this, expression.unit) + elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: + datatype = _coerce_date(expression.this, expression.unit) + else: + datatype = exp.DataType.Type.UNKNOWN + + self._set_type(expression, datatype) + return expression + + def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: + bracket_arg = expression.expressions[0] + this = expression.this + + if isinstance(bracket_arg, exp.Slice): + self._set_type(expression, this.type) + elif this.type.is_type(exp.DataType.Type.ARRAY): + self._set_type(expression, seq_get(this.type.expressions, 0)) + elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: + index = this.keys.index(bracket_arg) + value = seq_get(this.values, index) + self._set_type(expression, value.type if value else None) + else: + self._set_type(expression, exp.DataType.Type.UNKNOWN) + + return expression + + def _annotate_div(self, expression: exp.Div) -> exp.Div: + left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore + + if ( + expression.args.get("typed") + and left_type in exp.DataType.INTEGER_TYPES + and right_type in exp.DataType.INTEGER_TYPES + ): + self._set_type(expression, exp.DataType.Type.BIGINT) + else: + self._set_type(expression, self._maybe_coerce(left_type, right_type)) + if expression.type and expression.type.this not in exp.DataType.REAL_TYPES: + self._set_type( + expression, + self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE), + ) + + return expression + + def 
_annotate_dot(self, expression: exp.Dot) -> exp.Dot: + self._set_type(expression, None) + this_type = expression.this.type + + if this_type and this_type.is_type(exp.DataType.Type.STRUCT): + for e in this_type.expressions: + if e.name == expression.expression.name: + self._set_type(expression, e.kind) + break + + return expression + + def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: + self._set_type(expression, seq_get(expression.this.type.expressions, 0)) + return expression + + def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest: + child = seq_get(expression.expressions, 0) + + if child and child.is_type(exp.DataType.Type.ARRAY): + expr_type = seq_get(child.type.expressions, 0) + else: + expr_type = None + + self._set_type(expression, expr_type) + return expression + + def _annotate_subquery(self, expression: exp.Subquery) -> exp.Subquery: + # For scalar subqueries (subqueries with a single projection), infer the type + # from that single projection. This allows type propagation in cases like: + # SELECT (SELECT 1 AS c) AS c + query = expression.unnest() + + if isinstance(query, exp.Query): + selects = query.selects + if len(selects) == 1: + self._set_type(expression, selects[0].type) + return expression + + self._set_type(expression, exp.DataType.Type.UNKNOWN) + return expression + + def _annotate_struct_value( + self, expression: exp.Expression + ) -> t.Optional[exp.DataType] | exp.ColumnDef: + # Case: STRUCT(key AS value) + this: t.Optional[exp.Expression] = None + kind = expression.type + + if alias := expression.args.get("alias"): + this = alias.copy() + elif expression.expression: + # Case: STRUCT(key = value) or STRUCT(key := value) + this = expression.this.copy() + kind = expression.expression.type + elif isinstance(expression, exp.Column): + # Case: STRUCT(c) + this = expression.this.copy() + + if kind and kind.is_type(exp.DataType.Type.UNKNOWN): + return None + + if this: + return exp.ColumnDef(this=this, kind=kind) + + return kind + + def _annotate_struct(self, expression: exp.Struct) -> exp.Struct: + expressions = [] + for expr in expression.expressions: + struct_field_type = self._annotate_struct_value(expr) + if struct_field_type is None: + self._set_type(expression, None) + return expression + + expressions.append(struct_field_type) + + self._set_type( + expression, + exp.DataType( + this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True + ), + ) + return expression + + @t.overload + def _annotate_map(self, expression: exp.Map) -> exp.Map: + ... + + @t.overload + def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: + ... 
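+
+    # Illustrative: MAP(ARRAY('a'), ARRAY(1)) is annotated with the nested type
+    # MAP<VARCHAR, INT>; if either the key or the value type is unknown, the
+    # result stays a bare, un-parameterized MAP.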
+ + def _annotate_map(self, expression): + keys = expression.args.get("keys") + values = expression.args.get("values") + + map_type = exp.DataType(this=exp.DataType.Type.MAP) + if isinstance(keys, exp.Array) and isinstance(values, exp.Array): + key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN + value_type = ( + seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN + ) + + if ( + key_type != exp.DataType.Type.UNKNOWN + and value_type != exp.DataType.Type.UNKNOWN + ): + map_type.set("expressions", [key_type, value_type]) + map_type.set("nested", True) + + self._set_type(expression, map_type) + return expression + + def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap: + map_type = exp.DataType(this=exp.DataType.Type.MAP) + arg = expression.this + if arg.is_type(exp.DataType.Type.STRUCT): + for coldef in arg.type.expressions: + kind = coldef.kind + if kind != exp.DataType.Type.UNKNOWN: + map_type.set("expressions", [exp.DataType.build("varchar"), kind]) + map_type.set("nested", True) + break + + self._set_type(expression, map_type) + return expression + + def _annotate_extract(self, expression: exp.Extract) -> exp.Extract: + part = expression.name + if part == "TIME": + self._set_type(expression, exp.DataType.Type.TIME) + elif part == "DATE": + self._set_type(expression, exp.DataType.Type.DATE) + else: + self._set_type(expression, exp.DataType.Type.INT) + return expression + + def _annotate_by_array_element(self, expression: exp.Expression) -> exp.Expression: + array_arg = expression.this + if array_arg.type.is_type(exp.DataType.Type.ARRAY): + element_type = ( + seq_get(array_arg.type.expressions, 0) or exp.DataType.Type.UNKNOWN + ) + self._set_type(expression, element_type) + else: + self._set_type(expression, exp.DataType.Type.UNKNOWN) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/canonicalize.py b/third_party/bigframes_vendored/sqlglot/optimizer/canonicalize.py new file mode 100644 index 0000000000..bb7aa50a27 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/canonicalize.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import itertools +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType +from bigframes_vendored.sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime +from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator + + +def canonicalize( + expression: exp.Expression, dialect: DialectType = None +) -> exp.Expression: + """Converts a sql expression into a standard form. + + This method relies on annotate_types because many of the + conversions rely on type inference. + + Args: + expression: The expression to canonicalize. 
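+        dialect: The dialect to canonicalize for; it drives the date-function
+            rewrites and whether comparisons promote to an inferred datetime type.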
+ """ + + dialect = Dialect.get_or_raise(dialect) + + def _canonicalize(expression: exp.Expression) -> exp.Expression: + expression = add_text_to_concat(expression) + expression = replace_date_funcs(expression, dialect=dialect) + expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE) + expression = remove_redundant_casts(expression) + expression = ensure_bools(expression, _replace_int_predicate) + expression = remove_ascending_order(expression) + return expression + + return exp.replace_tree(expression, _canonicalize) + + +def add_text_to_concat(node: exp.Expression) -> exp.Expression: + if ( + isinstance(node, exp.Add) + and node.type + and node.type.this in exp.DataType.TEXT_TYPES + ): + node = exp.Concat( + expressions=[node.left, node.right], + # All known dialects, i.e. Redshift and T-SQL, that support + # concatenating strings with the + operator do not coalesce NULLs. + coalesce=False, + ) + return node + + +def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression: + if ( + isinstance(node, (exp.Date, exp.TsOrDsToDate)) + and not node.expressions + and not node.args.get("zone") + and node.this.is_string + and is_iso_date(node.this.name) + ): + return exp.cast(node.this, to=exp.DataType.Type.DATE) + if isinstance(node, exp.Timestamp) and not node.args.get("zone"): + if not node.type: + from bigframes_vendored.sqlglot.optimizer.annotate_types import ( + annotate_types, + ) + + node = annotate_types(node, dialect=dialect) + return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) + + return node + + +COERCIBLE_DATE_OPS = ( + exp.Add, + exp.Sub, + exp.EQ, + exp.NEQ, + exp.GT, + exp.GTE, + exp.LT, + exp.LTE, + exp.NullSafeEQ, + exp.NullSafeNEQ, +) + + +def coerce_type( + node: exp.Expression, promote_to_inferred_datetime_type: bool +) -> exp.Expression: + if isinstance(node, COERCIBLE_DATE_OPS): + _coerce_date(node.left, node.right, promote_to_inferred_datetime_type) + elif isinstance(node, exp.Between): + _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type) + elif isinstance(node, exp.Extract) and not node.expression.is_type( + *exp.DataType.TEMPORAL_TYPES + ): + _replace_cast(node.expression, exp.DataType.Type.DATETIME) + elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): + _coerce_timeunit_arg(node.this, node.unit) + elif isinstance(node, exp.DateDiff): + _coerce_datediff_args(node) + + return node + + +def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Cast) + and expression.this.type + and expression.to == expression.this.type + ): + return expression.this + + if ( + isinstance(expression, (exp.Date, exp.TsOrDsToDate)) + and expression.this.type + and expression.this.type.this == exp.DataType.Type.DATE + and not expression.this.type.expressions + ): + return expression.this + + return expression + + +def ensure_bools( + expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] +) -> exp.Expression: + if isinstance(expression, exp.Connector): + replace_func(expression.left) + replace_func(expression.right) + elif isinstance(expression, exp.Not): + replace_func(expression.this) + # We can't replace num in CASE x WHEN num ..., because it's not the full predicate + elif isinstance(expression, exp.If) and not ( + isinstance(expression.parent, exp.Case) and expression.parent.this + ): + replace_func(expression.this) + elif isinstance(expression, (exp.Where, exp.Having)): + replace_func(expression.this) + + 
return expression + + +def remove_ascending_order(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: + # Convert ORDER BY a ASC to ORDER BY a + expression.set("desc", None) + + return expression + + +def _coerce_date( + a: exp.Expression, + b: exp.Expression, + promote_to_inferred_datetime_type: bool, +) -> None: + for a, b in itertools.permutations([a, b]): + if isinstance(b, exp.Interval): + a = _coerce_timeunit_arg(a, b.unit) + + a_type = a.type + if ( + not a_type + or a_type.this not in exp.DataType.TEMPORAL_TYPES + or not b.type + or b.type.this not in exp.DataType.TEXT_TYPES + ): + continue + + if promote_to_inferred_datetime_type: + if b.is_string: + date_text = b.name + if is_iso_date(date_text): + b_type = exp.DataType.Type.DATE + elif is_iso_datetime(date_text): + b_type = exp.DataType.Type.DATETIME + else: + b_type = a_type.this + else: + # If b is not a datetime string, we conservatively promote it to a DATETIME, + # in order to ensure there are no surprising truncations due to downcasting + b_type = exp.DataType.Type.DATETIME + + target_type = ( + b_type + if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) + else a_type + ) + else: + target_type = a_type + + if target_type != a_type: + _replace_cast(a, target_type) + + _replace_cast(b, target_type) + + +def _coerce_timeunit_arg( + arg: exp.Expression, unit: t.Optional[exp.Expression] +) -> exp.Expression: + if not arg.type: + return arg + + if arg.type.this in exp.DataType.TEXT_TYPES: + date_text = arg.name + is_iso_date_ = is_iso_date(date_text) + + if is_iso_date_ and is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date_ or is_iso_datetime(date_text): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + return arg + + +def _coerce_datediff_args(node: exp.DateDiff) -> None: + for e in (node.this, node.expression): + if e.type.this not in exp.DataType.TEMPORAL_TYPES: + e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) + + +def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None: + node.replace(exp.cast(node.copy(), to=to)) + + +# this was originally designed for presto, there is a similar transform for tsql +# this is different in that it only operates on int types, this is because +# presto has a boolean type whereas tsql doesn't (people use bits) +# with y as (select true as x) select x = 0 FROM y -- illegal presto query +def _replace_int_predicate(expression: exp.Expression) -> None: + if isinstance(expression, exp.Coalesce): + for child in expression.iter_expressions(): + _replace_int_predicate(child) + elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + expression.replace(expression.neq(0)) diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_ctes.py b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_ctes.py new file mode 100644 index 0000000000..e22319ad6c --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_ctes.py @@ -0,0 +1,43 @@ +from bigframes_vendored.sqlglot.optimizer.scope import build_scope, Scope + + +def eliminate_ctes(expression): + """ + Remove unused CTEs from an expression. 
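+
+    Removal happens bottom-up: dropping one unused CTE decrements the reference
+    counts of the CTEs it selects from, so whole chains of unused CTEs are removed.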
+ + Example: + >>> import sqlglot + >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_ctes(expression).sql() + 'SELECT a FROM z' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + root = build_scope(expression) + + if root: + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if with_node and len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_joins.py b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_joins.py new file mode 100644 index 0000000000..966846da16 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_joins.py @@ -0,0 +1,189 @@ +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.optimizer.normalize import normalized +from bigframes_vendored.sqlglot.optimizer.scope import Scope, traverse_scope + + +def eliminate_joins(expression): + """ + Remove unused joins from an expression. + + This only removes joins when we know that the join condition doesn't produce duplicate rows. + + Example: + >>> import sqlglot + >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_joins(expression).sql() + 'SELECT x.a FROM x' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in traverse_scope(expression): + # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used. + # It's probably possible to infer this from the outputs of derived tables. + # But for now, let's just skip this rule. + if scope.unqualified_columns: + continue + + joins = scope.expression.args.get("joins", []) + + # Reverse the joins so we can remove chains of unused joins + for join in reversed(joins): + if join.is_semi_or_anti_join: + continue + + alias = join.alias_or_name + if _should_eliminate_join(scope, join, alias): + join.pop() + scope.remove_source(alias) + return expression + + +def _should_eliminate_join(scope, join, alias): + inner_source = scope.sources.get(alias) + return ( + isinstance(inner_source, Scope) + and not _join_is_used(scope, join, alias) + and ( + ( + join.side == "LEFT" + and _is_joined_on_all_unique_outputs(inner_source, join) + ) + or (not join.args.get("on") and _has_single_output_row(inner_source)) + ) + ) + + +def _join_is_used(scope, join, alias): + # We need to find all columns that reference this join. + # But columns in the ON clause shouldn't count. 
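+    # For example, in "SELECT x.a FROM x LEFT JOIN y ON x.id = y.id", the only
+    # reference to y is in the ON clause, so the join is a candidate for elimination.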
+ on = join.args.get("on") + if on: + on_clause_columns = {id(column) for column in on.find_all(exp.Column)} + else: + on_clause_columns = set() + return any( + column + for column in scope.source_columns(alias) + if id(column) not in on_clause_columns + ) + + +def _is_joined_on_all_unique_outputs(scope, join): + unique_outputs = _unique_outputs(scope) + if not unique_outputs: + return False + + _, join_keys, _ = join_condition(join) + remaining_unique_outputs = unique_outputs - {c.name for c in join_keys} + return not remaining_unique_outputs + + +def _unique_outputs(scope): + """Determine output columns of `scope` that must have a unique combination per row""" + if scope.expression.args.get("distinct"): + return set(scope.expression.named_selects) + + group = scope.expression.args.get("group") + if group: + grouped_expressions = set(group.expressions) + grouped_outputs = set() + + unique_outputs = set() + for select in scope.expression.selects: + output = select.unalias() + if output in grouped_expressions: + grouped_outputs.add(output) + unique_outputs.add(select.alias_or_name) + + # All the grouped expressions must be in the output + if not grouped_expressions.difference(grouped_outputs): + return unique_outputs + else: + return set() + + if _has_single_output_row(scope): + return set(scope.expression.named_selects) + + return set() + + +def _has_single_output_row(scope): + return isinstance(scope.expression, exp.Select) and ( + all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) + or _is_limit_1(scope) + or not scope.expression.args.get("from_") + ) + + +def _is_limit_1(scope): + limit = scope.expression.args.get("limit") + return limit and limit.expression.this == "1" + + +def join_condition(join): + """ + Extract the join condition from a join expression. 
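+
+    For example, "x JOIN y ON x.a = y.b AND y.b > 1" yields source keys [x.a] and
+    join keys [y.b], leaving "y.b > 1" (plus TRUE placeholders) in the predicate.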
+ + Args: + join (exp.Join) + Returns: + tuple[list[str], list[str], exp.Expression]: + Tuple of (source key, join key, remaining predicate) + """ + name = join.alias_or_name + on = (join.args.get("on") or exp.true()).copy() + source_key = [] + join_key = [] + + def extract_condition(condition): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.true()) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.true()) + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + if normalized(on): + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False) + + for condition in on.flatten(): + if isinstance(condition, exp.EQ): + extract_condition(condition) + elif normalized(on, dnf=True): + conditions = None + + for condition in on.flatten(): + parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] + if conditions is None: + conditions = parts + else: + temp = [] + for p in parts: + cs = [c for c in conditions if p == c] + + if cs: + temp.append(p) + temp.extend(cs) + conditions = temp + + for condition in conditions: + extract_condition(condition) + + return source_key, join_key, on diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_subqueries.py b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_subqueries.py new file mode 100644 index 0000000000..bbc2480e67 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/eliminate_subqueries.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import itertools +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.helper import find_new_name +from bigframes_vendored.sqlglot.optimizer.scope import build_scope, Scope + +if t.TYPE_CHECKING: + ExistingCTEsMapping = t.Dict[exp.Expression, str] + TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]] + + +def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: + """ + Rewrite derived tables as CTES, deduplicating if possible. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y") + >>> eliminate_subqueries(expression).sql() + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' + + This also deduplicates common subqueries: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") + >>> eliminate_subqueries(expression).sql() + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' + + Args: + expression (sqlglot.Expression): expression + Returns: + sqlglot.Expression: expression + """ + if isinstance(expression, exp.Subquery): + # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1 + eliminate_subqueries(expression.this) + return expression + + root = build_scope(expression) + + if not root: + return expression + + # Map of alias->Scope|Table + # These are all aliases that are already used in the expression. + # We don't want to create new CTEs that conflict with these names. 
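+    # For example, if a table named "cte" already appears in the query, a generated
+    # CTE alias is picked via find_new_name instead (e.g. "cte_2").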
+ taken: TakenNameMapping = {} + + # All CTE aliases in the root scope are taken + for scope in root.cte_scopes: + taken[scope.expression.parent.alias] = scope + + # All table names are taken + for scope in root.traverse(): + taken.update( + { + source.name: source + for _, source in scope.sources.items() + if isinstance(source, exp.Table) + } + ) + + # Map of Expression->alias + # Existing CTES in the root expression. We'll use this for deduplication. + existing_ctes: ExistingCTEsMapping = {} + + with_ = root.expression.args.get("with_") + recursive = False + if with_: + recursive = with_.args.get("recursive") + for cte in with_.expressions: + existing_ctes[cte.this] = cte.alias + new_ctes = [] + + # We're adding more CTEs, but we want to maintain the DAG order. + # Derived tables within an existing CTE need to come before the existing CTE. + for cte_scope in root.cte_scopes: + # Append all the new CTEs from this existing CTE + for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue + new_cte = _eliminate(scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) + + # Append the existing CTE itself + new_ctes.append(cte_scope.expression.parent) + + # Now append the rest + for scope in itertools.chain( + root.union_scopes, root.subquery_scopes, root.table_scopes + ): + for child_scope in scope.traverse(): + new_cte = _eliminate(child_scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) + + if new_ctes: + query = expression.expression if isinstance(expression, exp.DDL) else expression + query.set("with_", exp.With(expressions=new_ctes, recursive=recursive)) + + return expression + + +def _eliminate( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping +) -> t.Optional[exp.Expression]: + if scope.is_derived_table: + return _eliminate_derived_table(scope, existing_ctes, taken) + + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + + return None + + +def _eliminate_derived_table( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping +) -> t.Optional[exp.Expression]: + # This makes sure that we don't: + # - drop the "pivot" arg from a pivoted subquery + # - eliminate a lateral correlated subquery + if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral): + return None + + # Get rid of redundant exp.Subquery expressions, i.e. 
those that are just used as wrappers + to_replace = scope.expression.parent.unwrap() + name, cte = _new_cte(scope, existing_ctes, taken) + table = exp.alias_(exp.table_(name), alias=to_replace.alias or name) + table.set("joins", to_replace.args.get("joins")) + + to_replace.replace(table) + + return cte + + +def _eliminate_cte( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping +) -> t.Optional[exp.Expression]: + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + with_ = parent.parent + parent.pop() + if not with_.expressions: + with_.pop() + + # Rename references to this CTE + for child_scope in scope.parent.traverse(): + for table, source in child_scope.selected_sources.values(): + if source is scope: + new_table = exp.alias_( + exp.table_(name), alias=table.alias_or_name, copy=False + ) + table.replace(new_table) + + return cte + + +def _new_cte( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping +) -> t.Tuple[str, t.Optional[exp.Expression]]: + """ + Returns: + tuple of (name, cte) + where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. + If this CTE duplicates an existing CTE, `cte` will be None. + """ + duplicate_cte_alias = existing_ctes.get(scope.expression) + parent = scope.expression.parent + name = parent.alias + + if not name: + name = find_new_name(taken=taken, base="cte") + + if duplicate_cte_alias: + name = duplicate_cte_alias + elif taken.get(name): + name = find_new_name(taken=taken, base=name) + + taken[name] = scope + + if not duplicate_cte_alias: + existing_ctes[scope.expression] = name + cte = exp.CTE( + this=scope.expression, + alias=exp.TableAlias(this=exp.to_identifier(name)), + ) + else: + cte = None + return name, cte diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/isolate_table_selects.py b/third_party/bigframes_vendored/sqlglot/optimizer/isolate_table_selects.py new file mode 100644 index 0000000000..d3742936d9 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/isolate_table_selects.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import alias, exp +from bigframes_vendored.sqlglot.errors import OptimizeError +from bigframes_vendored.sqlglot.optimizer.scope import traverse_scope +from bigframes_vendored.sqlglot.schema import ensure_schema + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from bigframes_vendored.sqlglot.schema import Schema + + +def isolate_table_selects( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + dialect: DialectType = None, +) -> E: + schema = ensure_schema(schema, dialect=dialect) + + for scope in traverse_scope(expression): + if len(scope.selected_sources) == 1: + continue + + for _, source in scope.selected_sources.values(): + assert source.parent + + if ( + not isinstance(source, exp.Table) + or not schema.column_names(source) + or isinstance(source.parent, exp.Subquery) + or isinstance(source.parent.parent, exp.Table) + ): + continue + + if not source.alias: + raise OptimizeError( + "Tables require an alias. Run qualify_tables optimization." 
+ ) + + source.replace( + exp.select("*") + .from_( + alias(source, source.alias_or_name, table=True), + copy=False, + ) + .subquery(source.alias, copy=False) + ) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/merge_subqueries.py b/third_party/bigframes_vendored/sqlglot/optimizer/merge_subqueries.py new file mode 100644 index 0000000000..847368a9ff --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/merge_subqueries.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +from collections import defaultdict +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.helper import find_new_name, seq_get +from bigframes_vendored.sqlglot.optimizer.scope import Scope, traverse_scope + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + + FromOrJoin = t.Union[exp.From, exp.Join] + + +def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E: + """ + Rewrite sqlglot AST to merge derived tables into the outer query. + + This also merges CTEs if they are selected from only once. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") + >>> merge_subqueries(expression).sql() + 'SELECT x.a FROM x CROSS JOIN y' + + If `leave_tables_isolated` is True, this will not merge inner queries into outer + queries if it would result in multiple table selects in a single query: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") + >>> merge_subqueries(expression, leave_tables_isolated=True).sql() + 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y' + + Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html + + Args: + expression (sqlglot.Expression): expression to optimize + leave_tables_isolated (bool): + Returns: + sqlglot.Expression: optimized expression + """ + expression = merge_ctes(expression, leave_tables_isolated) + expression = merge_derived_tables(expression, leave_tables_isolated) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from_", + "joins", + "where", + "order", + "hint", +} + + +# Projections in the outer query that are instances of these types can be replaced +# without getting wrapped in parentheses, because the precedence won't be altered. +SAFE_TO_REPLACE_UNWRAPPED = ( + exp.Column, + exp.EQ, + exp.Func, + exp.NEQ, + exp.Paren, +) + + +def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E: + scopes = traverse_scope(expression) + + # All places where we select from CTEs. + # We key on the CTE scope so we can detect CTES that are selected from multiple times. 
+ cte_selections = defaultdict(list) + for outer_scope in scopes: + for table, inner_scope in outer_scope.selected_sources.values(): + if isinstance(inner_scope, Scope) and inner_scope.is_cte: + cte_selections[id(inner_scope)].append( + ( + outer_scope, + inner_scope, + table, + ) + ) + + singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] + for outer_scope, inner_scope, table in singular_cte_selections: + from_or_join = table.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): + alias = table.alias_or_name + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, table, alias) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_order(outer_scope, inner_scope) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_hints(outer_scope, inner_scope) + _pop_cte(inner_scope) + outer_scope.clear_cache() + return expression + + +def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E: + for outer_scope in traverse_scope(expression): + for subquery in outer_scope.derived_tables: + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + alias = subquery.alias_or_name + inner_scope = outer_scope.sources[alias] + if _mergeable( + outer_scope, inner_scope, leave_tables_isolated, from_or_join + ): + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, subquery, alias) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_order(outer_scope, inner_scope) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_hints(outer_scope, inner_scope) + outer_scope.clear_cache() + + return expression + + +def _mergeable( + outer_scope: Scope, + inner_scope: Scope, + leave_tables_isolated: bool, + from_or_join: FromOrJoin, +) -> bool: + """ + Return True if `inner_select` can be merged into outer query. + """ + inner_select = inner_scope.expression.unnest() + + def _is_a_window_expression_in_unmergable_operation(): + window_aliases = { + s.alias_or_name for s in inner_select.selects if s.find(exp.Window) + } + inner_select_name = from_or_join.alias_or_name + unmergable_window_columns = [ + column + for column in outer_scope.columns + if column.find_ancestor( + exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc + ) + ] + window_expressions_in_unmergable = [ + column + for column in unmergable_window_columns + if column.table == inner_select_name and column.name in window_aliases + ] + return any(window_expressions_in_unmergable) + + def _outer_select_joins_on_inner_select_join(): + """ + All columns from the inner select in the ON clause must be from the first FROM table. 
+ + That is, this can be merged: + SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + But this can't: + SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + """ + if not isinstance(from_or_join, exp.Join): + return False + + alias = from_or_join.alias_or_name + + on = from_or_join.args.get("on") + if not on: + return False + selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] + inner_from = inner_scope.expression.args.get("from_") + if not inner_from: + return False + inner_from_table = inner_from.alias_or_name + inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects} + return any( + col.table != inner_from_table + for selection in selections + for col in inner_projections[selection].find_all(exp.Column) + ) + + def _is_recursive(): + # Recursive CTEs look like this: + # WITH RECURSIVE cte AS ( + # SELECT * FROM x <-- inner scope + # UNION ALL + # SELECT * FROM cte <-- outer scope + # ) + cte = inner_scope.expression.parent + node = outer_scope.expression.parent + + while node: + if node is cte: + return True + node = node.parent + return False + + return ( + isinstance(outer_scope.expression, exp.Select) + and not outer_scope.expression.is_star + and isinstance(inner_select, exp.Select) + and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) + and inner_select.args.get("from_") is not None + and not outer_scope.pivots + and not any( + e.find(exp.AggFunc, exp.Select, exp.Explode) + for e in inner_select.expressions + ) + and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + and not (isinstance(from_or_join, exp.Join) and inner_select.args.get("joins")) + and not ( + isinstance(from_or_join, exp.Join) + and inner_select.args.get("where") + and from_or_join.side in ("FULL", "LEFT", "RIGHT") + ) + and not ( + isinstance(from_or_join, exp.From) + and inner_select.args.get("where") + and any( + j.side in ("FULL", "RIGHT") + for j in outer_scope.expression.args.get("joins", []) + ) + ) + and not _outer_select_joins_on_inner_select_join() + and not _is_a_window_expression_in_unmergable_operation() + and not _is_recursive() + and not (inner_select.args.get("order") and outer_scope.is_union) + and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform) + ) + + +def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: + """ + Renames any sources in the inner query that conflict with names in the outer query. 
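+
+    For example, if both queries select from an alias "t", the inner "t" is renamed
+    to a fresh name (e.g. "t_2") and the inner scope's column references are updated.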
+ """ + inner_taken = set(inner_scope.selected_sources) + outer_taken = set(outer_scope.selected_sources) + conflicts = outer_taken.intersection(inner_taken) + conflicts -= {alias} + + taken = outer_taken.union(inner_taken) + + for conflict in conflicts: + new_name = find_new_name(taken, conflict) + + source, _ = inner_scope.selected_sources[conflict] + new_alias = exp.to_identifier(new_name) + + if isinstance(source, exp.Table) and source.alias: + source.set("alias", new_alias) + elif isinstance(source, exp.Table): + source.replace(exp.alias_(source, new_alias)) + elif isinstance(source.parent, exp.Subquery): + source.parent.set("alias", exp.TableAlias(this=new_alias)) + + for column in inner_scope.source_columns(conflict): + column.set("table", exp.to_identifier(new_name)) + + inner_scope.rename_source(conflict, new_name) + + +def _merge_from( + outer_scope: Scope, + inner_scope: Scope, + node_to_replace: t.Union[exp.Subquery, exp.Table], + alias: str, +) -> None: + """ + Merge FROM clause of inner query into outer query. + """ + new_subquery = inner_scope.expression.args["from_"].this + new_subquery.set("joins", node_to_replace.args.get("joins")) + node_to_replace.replace(new_subquery) + for join_hint in outer_scope.join_hints: + tables = join_hint.find_all(exp.Table) + for table in tables: + if table.alias_or_name == node_to_replace.alias_or_name: + table.set("this", exp.to_identifier(new_subquery.alias_or_name)) + outer_scope.remove_source(alias) + outer_scope.add_source( + new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name] + ) + + +def _merge_joins( + outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin +) -> None: + """ + Merge JOIN clauses of inner query into outer query. + """ + + new_joins = [] + + joins = inner_scope.expression.args.get("joins") or [] + + for join in joins: + new_joins.append(join) + outer_scope.add_source( + join.alias_or_name, inner_scope.sources[join.alias_or_name] + ) + + if new_joins: + outer_joins = outer_scope.expression.args.get("joins", []) + + # Maintain the join order + if isinstance(from_or_join, exp.From): + position = 0 + else: + position = outer_joins.index(from_or_join) + 1 + outer_joins[position:position] = new_joins + + outer_scope.expression.set("joins", outer_joins) + + +def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: + """ + Merge projections of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + # Collect all columns that reference the alias of the inner query + outer_columns = defaultdict(list) + for column in outer_scope.columns: + if column.table == alias: + outer_columns[column.name].append(column) + + # Replace columns with the projection expression in the inner query + for expression in inner_scope.expression.expressions: + projection_name = expression.alias_or_name + if not projection_name: + continue + columns_to_replace = outer_columns.get(projection_name, []) + + expression = expression.unalias() + must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED) + + for column in columns_to_replace: + # Ensures we don't alter the intended operator precedence if there's additional + # context surrounding the outer expression (i.e. it's not a simple projection). 
+ if ( + isinstance(column.parent, (exp.Unary, exp.Binary)) + and must_wrap_expression + ): + expression = exp.paren(expression, copy=False) + + # make sure we do not accidentally change the name of the column + if isinstance(column.parent, exp.Select) and column.name != expression.name: + expression = exp.alias_(expression, column.name) + + column.replace(expression.copy()) + + +def _merge_where( + outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin +) -> None: + """ + Merge WHERE clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + where = inner_scope.expression.args.get("where") + if not where or not where.this: + return + + expression = outer_scope.expression + + if isinstance(from_or_join, exp.Join): + # Merge predicates from an outer join to the ON clause + # if it only has columns that are already joined + from_ = expression.args.get("from_") + sources = {from_.alias_or_name} if from_ else set() + + for join in expression.args["joins"]: + source = join.alias_or_name + sources.add(source) + if source == from_or_join.alias_or_name: + break + + if exp.column_table_names(where.this) <= sources: + from_or_join.on(where.this, copy=False) + from_or_join.set("on", from_or_join.args.get("on")) + return + + expression.where(where.this, copy=False) + + +def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None: + """ + Merge ORDER clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + """ + if ( + any( + outer_scope.expression.args.get(arg) + for arg in ["group", "distinct", "having", "order"] + ) + or len(outer_scope.selected_sources) != 1 + or any( + expression.find(exp.AggFunc) + for expression in outer_scope.expression.expressions + ) + ): + return + + outer_scope.expression.set("order", inner_scope.expression.args.get("order")) + + +def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None: + inner_scope_hint = inner_scope.expression.args.get("hint") + if not inner_scope_hint: + return + outer_scope_hint = outer_scope.expression.args.get("hint") + if outer_scope_hint: + for hint_expression in inner_scope_hint.expressions: + outer_scope_hint.append("expressions", hint_expression) + else: + outer_scope.expression.set("hint", inner_scope_hint) + + +def _pop_cte(inner_scope: Scope) -> None: + """ + Remove CTE from the AST. 
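The join branch of `_merge_where` above hinges on the test `exp.column_table_names(where.this) <= sources`; a small sketch of that helper (upstream package):

```python
import sqlglot
from sqlglot import exp

# A predicate may move into an ON clause only if every table it
# references is already part of the join chain at that point.
cond = sqlglot.condition("x.a = 1 AND y.b = 2")
print(exp.column_table_names(cond))                # {'x', 'y'} (a set; order varies)
print(exp.column_table_names(cond) <= {"x", "y"})  # True -> safe to merge into ON
```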
+ + Args: + inner_scope (sqlglot.optimizer.scope.Scope) + """ + cte = inner_scope.expression.parent + with_ = cte.parent + if len(with_.expressions) == 1: + with_.pop() + else: + cte.pop() diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/normalize.py b/third_party/bigframes_vendored/sqlglot/optimizer/normalize.py new file mode 100644 index 0000000000..c0633ac3a5 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/normalize.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import logging + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.errors import OptimizeError +from bigframes_vendored.sqlglot.helper import while_changing +from bigframes_vendored.sqlglot.optimizer.scope import find_all_in_scope +from bigframes_vendored.sqlglot.optimizer.simplify import flatten, Simplifier + +logger = logging.getLogger("sqlglot") + + +def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): + """ + Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(x AND y) OR z") + >>> normalize(expression, dnf=False).sql() + '(x OR z) AND (y OR z)' + + Args: + expression: expression to normalize + dnf: rewrite in disjunctive normal form instead. + max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion + Returns: + sqlglot.Expression: normalized expression + """ + simplifier = Simplifier(annotate_new_expressions=False) + + for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): + if isinstance(node, exp.Connector): + if normalized(node, dnf=dnf): + continue + root = node is expression + original = node.copy() + + node.transform(simplifier.rewrite_between, copy=False) + distance = normalization_distance(node, dnf=dnf, max_=max_distance) + + if distance > max_distance: + logger.info( + f"Skipping normalization because distance {distance} exceeds max {max_distance}" + ) + return expression + + try: + node = node.replace( + while_changing( + node, + lambda e: distributive_law( + e, dnf, max_distance, simplifier=simplifier + ), + ) + ) + except OptimizeError as e: + logger.info(e) + node.replace(original) + if root: + return original + return expression + + if root: + expression = node + + return expression + + +def normalized(expression: exp.Expression, dnf: bool = False) -> bool: + """ + Checks whether a given expression is in a normal form of interest. + + Example: + >>> from sqlglot import parse_one + >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default + True + >>> normalized(parse_one("a AND (b OR c)"), dnf=True) + False + + Args: + expression: The expression to check if it's normalized. + dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + return not any( + connector.find_ancestor(ancestor) + for connector in find_all_in_scope(expression, root) + ) + + +def normalization_distance( + expression: exp.Expression, dnf: bool = False, max_: float = float("inf") +) -> int: + """ + The difference in the number of predicates between a given expression and its normalized form. + + This is used as an estimate of the cost of the conversion which is exponential in complexity. 
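For one level deeper nesting than the docstring's example, a sketch of the CNF rewrite (upstream package):

```python
import sqlglot
from sqlglot.optimizer.normalize import normalize, normalized

expr = sqlglot.parse_one("a AND (b OR (c AND d))")
print(normalize(expr).sql())
# Roughly: a AND (b OR c) AND (b OR d)

# The result now passes the CNF check.
print(normalized(sqlglot.parse_one("a AND (b OR c) AND (b OR d)")))  # True
```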
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") + >>> normalization_distance(expression) + 4 + + Args: + expression: The expression to compute the normalization distance for. + dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + max_: stop early if count exceeds this. + + Returns: + The normalization distance. + """ + total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1) + + for length in _predicate_lengths(expression, dnf, max_): + total += length + if total > max_: + return total + + return total + + +def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0): + """ + Returns a list of predicate lengths when expanded to normalized form. + + (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). + """ + if depth > max_: + yield depth + return + + expression = expression.unnest() + + if not isinstance(expression, exp.Connector): + yield 1 + return + + depth += 1 + left, right = expression.args.values() + + if isinstance(expression, exp.And if dnf else exp.Or): + for a in _predicate_lengths(left, dnf, max_, depth): + for b in _predicate_lengths(right, dnf, max_, depth): + yield a + b + else: + yield from _predicate_lengths(left, dnf, max_, depth) + yield from _predicate_lengths(right, dnf, max_, depth) + + +def distributive_law(expression, dnf, max_distance, simplifier=None): + """ + x OR (y AND z) -> (x OR y) AND (x OR z) + (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) + """ + if normalized(expression, dnf=dnf): + return expression + + distance = normalization_distance(expression, dnf=dnf, max_=max_distance) + + if distance > max_distance: + raise OptimizeError( + f"Normalization distance {distance} exceeds max {max_distance}" + ) + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + + if isinstance(expression, from_exp): + a, b = expression.unnest_operands() + + from_func = exp.and_ if from_exp == exp.And else exp.or_ + to_func = exp.and_ if to_exp == exp.And else exp.or_ + + simplifier = simplifier or Simplifier(annotate_new_expressions=False) + + if isinstance(a, to_exp) and isinstance(b, to_exp): + if len(tuple(a.find_all(exp.Connector))) > len( + tuple(b.find_all(exp.Connector)) + ): + return _distribute(a, b, from_func, to_func, simplifier) + return _distribute(b, a, from_func, to_func, simplifier) + if isinstance(a, to_exp): + return _distribute(b, a, from_func, to_func, simplifier) + if isinstance(b, to_exp): + return _distribute(a, b, from_func, to_func, simplifier) + + return expression + + +def _distribute(a, b, from_func, to_func, simplifier): + if isinstance(a, exp.Connector): + exp.replace_children( + a, + lambda c: to_func( + simplifier.uniq_sort(flatten(from_func(c, b.left))), + simplifier.uniq_sort(flatten(from_func(c, b.right))), + copy=False, + ), + ) + else: + a = to_func( + simplifier.uniq_sort(flatten(from_func(a, b.left))), + simplifier.uniq_sort(flatten(from_func(a, b.right))), + copy=False, + ) + + return a diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/normalize_identifiers.py b/third_party/bigframes_vendored/sqlglot/optimizer/normalize_identifiers.py new file mode 100644 index 0000000000..7d10b73282 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/normalize_identifiers.py @@ -0,0 +1,86 @@ +from __future__ import annotations + 
+import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + + +@t.overload +def normalize_identifiers( + expression: E, + dialect: DialectType = None, + store_original_column_identifiers: bool = False, +) -> E: + ... + + +@t.overload +def normalize_identifiers( + expression: str, + dialect: DialectType = None, + store_original_column_identifiers: bool = False, +) -> exp.Identifier: + ... + + +def normalize_identifiers( + expression, dialect=None, store_original_column_identifiers=False +): + """ + Normalize identifiers by converting them to either lower or upper case, + ensuring the semantics are preserved in each case (e.g. by respecting + case-sensitivity). + + This transformation reflects how identifiers would be resolved by the engine corresponding + to each SQL dialect, and plays a very important role in the standardization of the AST. + + It's possible to make this a no-op by adding a special comment next to the + identifier of interest: + + SELECT a /* sqlglot.meta case_sensitive */ FROM table + + In this example, the identifier `a` will not be normalized. + + Note: + Some dialects (e.g. DuckDB) treat all identifiers as case-insensitive even + when they're quoted, so in these cases all identifiers are normalized. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> normalize_identifiers(expression).sql() + 'SELECT bar.a AS a FROM "Foo".bar' + >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake") + 'FOO' + + Args: + expression: The expression to transform. + dialect: The dialect to use in order to decide how to normalize identifiers. + store_original_column_identifiers: Whether to store the original column identifiers in + the meta data of the expression in case we want to undo the normalization at a later point. + + Returns: + The transformed expression. + """ + dialect = Dialect.get_or_raise(dialect) + + if isinstance(expression, str): + expression = exp.parse_identifier(expression, dialect=dialect) + + for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")): + if not node.meta.get("case_sensitive"): + if store_original_column_identifiers and isinstance(node, exp.Column): + # TODO: This does not handle non-column cases, e.g PARSE_JSON(...).key + parent = node + while parent and isinstance(parent.parent, exp.Dot): + parent = parent.parent + + node.meta["dot_parts"] = [p.name for p in parent.parts] + + dialect.normalize_identifier(node) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/optimize_joins.py b/third_party/bigframes_vendored/sqlglot/optimizer/optimize_joins.py new file mode 100644 index 0000000000..177565deb5 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/optimize_joins.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.helper import tsort + +JOIN_ATTRS = ("on", "side", "kind", "using", "method") + + +def optimize_joins(expression): + """ + Removes cross joins if possible and reorder joins based on predicate dependencies. 
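A sketch of the opt-out comment described in the `normalize_identifiers` docstring above (upstream package; identifiers invented, and exact comment placement matters when parsing):

```python
import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

sql = "SELECT A /* sqlglot.meta case_sensitive */, B FROM T"
print(normalize_identifiers(sqlglot.parse_one(sql)).sql())
# Roughly: SELECT A /* sqlglot.meta case_sensitive */, b FROM t
# `A` keeps its case because of the meta comment; `B` and `T` normalize.
```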
+ + Example: + >>> from sqlglot import parse_one + >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() + 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' + """ + + for select in expression.find_all(exp.Select): + joins = select.args.get("joins", []) + + if not _is_reorderable(joins): + continue + + references = {} + cross_joins = [] + + for join in joins: + tables = other_table_names(join) + + if tables: + for table in tables: + references[table] = references.get(table, []) + [join] + else: + cross_joins.append((join.alias_or_name, join)) + + for name, join in cross_joins: + for dep in references.get(name, []): + on = dep.args["on"] + + if isinstance(on, exp.Connector): + if len(other_table_names(dep)) < 2: + continue + + operator = type(on) + for predicate in on.flatten(): + if name in exp.column_table_names(predicate): + predicate.replace(exp.true()) + predicate = exp._combine( + [join.args.get("on"), predicate], operator, copy=False + ) + join.on(predicate, append=False, copy=False) + + expression = reorder_joins(expression) + expression = normalize(expression) + return expression + + +def reorder_joins(expression): + """ + Reorder joins by topological sort order based on predicate references. + """ + for from_ in expression.find_all(exp.From): + parent = from_.parent + joins = parent.args.get("joins", []) + + if not _is_reorderable(joins): + continue + + joins_by_name = {join.alias_or_name: join for join in joins} + dag = {name: other_table_names(join) for name, join in joins_by_name.items()} + parent.set( + "joins", + [ + joins_by_name[name] + for name in tsort(dag) + if name != from_.alias_or_name and name in joins_by_name + ], + ) + return expression + + +def normalize(expression): + """ + Remove INNER and OUTER from joins as they are optional. + """ + for join in expression.find_all(exp.Join): + if not any(join.args.get(k) for k in JOIN_ATTRS): + join.set("kind", "CROSS") + + if join.kind == "CROSS": + join.set("on", None) + else: + if join.kind in ("INNER", "OUTER"): + join.set("kind", None) + + if not join.args.get("on") and not join.args.get("using"): + join.set("on", exp.true()) + return expression + + +def other_table_names(join: exp.Join) -> t.Set[str]: + on = join.args.get("on") + return exp.column_table_names(on, join.alias_or_name) if on else set() + + +def _is_reorderable(joins: t.List[exp.Join]) -> bool: + """ + Checks if joins can be reordered without changing query semantics. + + Joins with a side (LEFT, RIGHT, FULL) cannot be reordered easily, + the order affects which rows are included in the result. 
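Note that the `normalize` helper in this module canonicalizes join kinds (it is unrelated to the CNF pass of the same name): INNER is dropped as redundant, and a bare JOIN with no predicate becomes an explicit CROSS JOIN. A sketch (upstream package):

```python
import sqlglot
from sqlglot.optimizer.optimize_joins import normalize

expr = sqlglot.parse_one("SELECT * FROM x INNER JOIN y ON x.a = y.a JOIN z")
print(normalize(expr).sql())
# Roughly: SELECT * FROM x JOIN y ON x.a = y.a CROSS JOIN z
```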
+ + Example: + >>> from sqlglot import parse_one, exp + >>> from sqlglot.optimizer.optimize_joins import _is_reorderable + >>> ast = parse_one("SELECT * FROM x JOIN y ON x.id = y.id JOIN z ON y.id = z.id") + >>> _is_reorderable(ast.find(exp.Select).args.get("joins", [])) + True + >>> ast = parse_one("SELECT * FROM x LEFT JOIN y ON x.id = y.id JOIN z ON y.id = z.id") + >>> _is_reorderable(ast.find(exp.Select).args.get("joins", [])) + False + """ + return not any(join.side for join in joins) diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/optimizer.py b/third_party/bigframes_vendored/sqlglot/optimizer/optimizer.py new file mode 100644 index 0000000000..4f425dae68 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/optimizer.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import inspect +import typing as t + +from bigframes_vendored.sqlglot import exp, Schema +from bigframes_vendored.sqlglot.dialects.dialect import DialectType +from bigframes_vendored.sqlglot.optimizer.annotate_types import annotate_types +from bigframes_vendored.sqlglot.optimizer.canonicalize import canonicalize +from bigframes_vendored.sqlglot.optimizer.eliminate_ctes import eliminate_ctes +from bigframes_vendored.sqlglot.optimizer.eliminate_joins import eliminate_joins +from bigframes_vendored.sqlglot.optimizer.eliminate_subqueries import ( + eliminate_subqueries, +) +from bigframes_vendored.sqlglot.optimizer.merge_subqueries import merge_subqueries +from bigframes_vendored.sqlglot.optimizer.normalize import normalize +from bigframes_vendored.sqlglot.optimizer.optimize_joins import optimize_joins +from bigframes_vendored.sqlglot.optimizer.pushdown_predicates import pushdown_predicates +from bigframes_vendored.sqlglot.optimizer.pushdown_projections import ( + pushdown_projections, +) +from bigframes_vendored.sqlglot.optimizer.qualify import qualify +from bigframes_vendored.sqlglot.optimizer.qualify_columns import quote_identifiers +from bigframes_vendored.sqlglot.optimizer.simplify import simplify +from bigframes_vendored.sqlglot.optimizer.unnest_subqueries import unnest_subqueries +from bigframes_vendored.sqlglot.schema import ensure_schema + +RULES = ( + qualify, + pushdown_projections, + normalize, + unnest_subqueries, + pushdown_predicates, + optimize_joins, + eliminate_subqueries, + merge_subqueries, + eliminate_joins, + eliminate_ctes, + quote_identifiers, + annotate_types, + canonicalize, + simplify, +) + + +def optimize( + expression: str | exp.Expression, + schema: t.Optional[dict | Schema] = None, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, + dialect: DialectType = None, + rules: t.Sequence[t.Callable] = RULES, + sql: t.Optional[str] = None, + **kwargs, +) -> exp.Expression: + """ + Rewrite a sqlglot AST into an optimized form. + + Args: + expression: expression to optimize + schema: database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + If no schema is provided then the default schema defined at `sqlgot.schema` will be used + db: specify the default database, as might be set by a `USE DATABASE db` statement + catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement + dialect: The dialect to parse the sql string. + rules: sequence of optimizer rules to use. + Many of the rules require tables and columns to be qualified. 
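Putting the `RULES` pipeline together end to end; a sketch (upstream package, invented schema):

```python
import sqlglot
from sqlglot.optimizer.optimizer import optimize

schema = {"x": {"a": "INT", "b": "INT"}}
sql = "SELECT a, SUM(b) AS s FROM x GROUP BY 1"
print(optimize(sqlglot.parse_one(sql), schema=schema).sql())
# Roughly: SELECT "x"."a" AS "a", SUM("x"."b") AS "s"
#          FROM "x" AS "x" GROUP BY "x"."a"
# Columns are qualified and quoted, and the GROUP BY ordinal is expanded.
```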
+ Do not remove `qualify` from the sequence of rules unless you know what you're doing! + sql: Original SQL string for error highlighting. If not provided, errors will not include + highlighting. Requires that the expression has position metadata from parsing. + **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. + + Returns: + The optimized expression. + """ + schema = ensure_schema(schema, dialect=dialect) + possible_kwargs = { + "db": db, + "catalog": catalog, + "schema": schema, + "dialect": dialect, + "sql": sql, + "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, + **kwargs, + } + + optimized = exp.maybe_parse(expression, dialect=dialect, copy=True) + for rule in rules: + # Find any additional rule parameters, beyond `expression` + rule_params = inspect.getfullargspec(rule).args + rule_kwargs = { + param: possible_kwargs[param] + for param in rule_params + if param in possible_kwargs + } + optimized = rule(optimized, **rule_kwargs) + + return optimized diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_predicates.py b/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_predicates.py new file mode 100644 index 0000000000..807aa98f36 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_predicates.py @@ -0,0 +1,235 @@ +from bigframes_vendored.sqlglot import Dialect, exp +from bigframes_vendored.sqlglot.optimizer.normalize import normalized +from bigframes_vendored.sqlglot.optimizer.scope import build_scope, find_in_scope +from bigframes_vendored.sqlglot.optimizer.simplify import simplify + + +def pushdown_predicates(expression, dialect=None): + """ + Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS + + Example: + >>> import sqlglot + >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_predicates(expression).sql() + 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + from bigframes_vendored.sqlglot.dialects.athena import Athena + from bigframes_vendored.sqlglot.dialects.presto import Presto + + root = build_scope(expression) + + dialect = Dialect.get_or_raise(dialect) + unnest_requires_cross_join = isinstance(dialect, (Athena, Presto)) + + if root: + scope_ref_count = root.ref_count() + + for scope in reversed(list(root.traverse())): + select = scope.expression + where = select.args.get("where") + if where: + selected_sources = scope.selected_sources + join_index = { + join.alias_or_name: i + for i, join in enumerate(select.args.get("joins") or []) + } + + # a right join can only push down to itself and not the source FROM table + # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression + pushdown_allowed = True + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join): + if parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + if isinstance(node, exp.Unnest) and unnest_requires_cross_join: + pushdown_allowed = False + break + + if pushdown_allowed: + pushdown( + where.this, + selected_sources, + scope_ref_count, + dialect, + join_index, + ) + + # joins should only pushdown into itself, not to other joins + # so we limit the selected sources to only itself 
+ for join in select.args.get("joins") or []: + name = join.alias_or_name + if name in scope.selected_sources: + pushdown( + join.args.get("on"), + {name: scope.selected_sources[name]}, + scope_ref_count, + dialect, + ) + + return expression + + +def pushdown(condition, sources, scope_ref_count, dialect, join_index=None): + if not condition: + return + + condition = condition.replace(simplify(condition, dialect=dialect)) + cnf_like = normalized(condition) or not normalized(condition, dnf=True) + + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) + + if cnf_like: + pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index) + else: + pushdown_dnf(predicates, sources, scope_ref_count) + + +def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None): + """ + If the predicates are in CNF like form, we can simply replace each block in the parent. + """ + join_index = join_index or {} + for predicate in predicates: + for node in nodes_for_predicate(predicate, sources, scope_ref_count).values(): + if isinstance(node, exp.Join): + name = node.alias_or_name + predicate_tables = exp.column_table_names(predicate, name) + + # Don't push the predicate if it references tables that appear in later joins + this_index = join_index[name] + if all( + join_index.get(table, -1) < this_index for table in predicate_tables + ): + predicate.replace(exp.true()) + node.on(predicate, copy=False) + break + if isinstance(node, exp.Select): + predicate.replace(exp.true()) + inner_predicate = replace_aliases(node, predicate) + if find_in_scope(inner_predicate, exp.AggFunc): + node.having(inner_predicate, copy=False) + else: + node.where(inner_predicate, copy=False) + + +def pushdown_dnf(predicates, sources, scope_ref_count): + """ + If the predicates are in DNF form, we can only push down conditions that are in all blocks. + Additionally, we can't remove predicates from their original form. 
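A sketch of the CNF path pushing a single-table predicate from WHERE into a join's ON clause (upstream package):

```python
import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

sql = "SELECT * FROM x JOIN y ON x.id = y.id WHERE x.a = 1 AND y.b = 2"
print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
# Roughly: SELECT * FROM x JOIN y ON x.id = y.id AND y.b = 2
#          WHERE x.a = 1 AND TRUE
# The pushed predicate is replaced by TRUE in place; a later simplify
# pass cleans up the leftover TRUE.
```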
+ """ + # find all the tables that can be pushdown too + # these are tables that are referenced in all blocks of a DNF + # (a.x AND b.x) OR (a.y AND c.y) + # only table a can be push down + pushdown_tables = set() + + for a in predicates: + a_tables = exp.column_table_names(a) + + for b in predicates: + a_tables &= exp.column_table_names(b) + + pushdown_tables.update(a_tables) + + conditions = {} + + # pushdown all predicates to their respective nodes + for table in sorted(pushdown_tables): + for predicate in predicates: + nodes = nodes_for_predicate(predicate, sources, scope_ref_count) + + if table not in nodes: + continue + + conditions[table] = ( + exp.or_(conditions[table], predicate) + if table in conditions + else predicate + ) + + for name, node in nodes.items(): + if name not in conditions: + continue + + predicate = conditions[name] + + if isinstance(node, exp.Join): + node.on(predicate, copy=False) + elif isinstance(node, exp.Select): + inner_predicate = replace_aliases(node, predicate) + if find_in_scope(inner_predicate, exp.AggFunc): + node.having(inner_predicate, copy=False) + else: + node.where(inner_predicate, copy=False) + + +def nodes_for_predicate(predicate, sources, scope_ref_count): + nodes = {} + tables = exp.column_table_names(predicate) + where_condition = isinstance( + predicate.find_ancestor(exp.Join, exp.Where), exp.Where + ) + + for table in sorted(tables): + node, source = sources.get(table) or (None, None) + + # if the predicate is in a where statement we can try to push it down + # we want to find the root join or from statement + if node and where_condition: + node = node.find_ancestor(exp.Join, exp.From) + + # a node can reference a CTE which should be pushed down + if isinstance(node, exp.From) and not isinstance(source, exp.Table): + with_ = source.parent.expression.args.get("with_") + if with_ and with_.recursive: + return {} + node = source.expression + + if isinstance(node, exp.Join): + if node.side and node.side != "RIGHT": + return {} + nodes[table] = node + elif isinstance(node, exp.Select) and len(tables) == 1: + # We can't push down window expressions + has_window_expression = any( + select for select in node.selects if select.find(exp.Window) + ) + # we can't push down predicates to select statements if they are referenced in + # multiple places. 
+ if ( + not node.args.get("group") + and scope_ref_count[id(source)] < 2 + and not has_window_expression + ): + nodes[table] = node + return nodes + + +def replace_aliases(source, predicate): + aliases = {} + + for select in source.selects: + if isinstance(select, exp.Alias): + aliases[select.alias] = select.this + else: + aliases[select.name] = select + + def _replace_alias(column): + if isinstance(column, exp.Column) and column.name in aliases: + return aliases[column.name].copy() + return column + + return predicate.transform(_replace_alias) diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_projections.py b/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_projections.py new file mode 100644 index 0000000000..ac3edbb7c8 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/pushdown_projections.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from collections import defaultdict +import typing as t + +from bigframes_vendored.sqlglot import alias, exp +from bigframes_vendored.sqlglot.errors import OptimizeError +from bigframes_vendored.sqlglot.helper import seq_get +from bigframes_vendored.sqlglot.optimizer.qualify_columns import Resolver +from bigframes_vendored.sqlglot.optimizer.scope import Scope, traverse_scope +from bigframes_vendored.sqlglot.schema import ensure_schema + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + from bigframes_vendored.sqlglot.schema import Schema + +# Sentinel value that means an outer query selecting ALL columns +SELECT_ALL = object() + + +# Selection to use if selection list is empty +def default_selection(is_agg: bool) -> exp.Alias: + return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_") + + +def pushdown_projections( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + remove_unused_selections: bool = True, + dialect: DialectType = None, +) -> E: + """ + Rewrite sqlglot AST to remove unused columns projections. + + Example: + >>> import sqlglot + >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_projections(expression).sql() + 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y' + + Args: + expression (sqlglot.Expression): expression to optimize + remove_unused_selections (bool): remove selects that are unused + Returns: + sqlglot.Expression: optimized expression + """ + # Map of Scope to all columns being selected by outer queries. + schema = ensure_schema(schema, dialect=dialect) + source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {} + referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set) + + # We build the scope tree (which is traversed in DFS postorder), then iterate + # over the result in reverse order. This should ensure that the set of selected + # columns for a particular scope are completely build by the time we get to it. + for scope in reversed(traverse_scope(expression)): + parent_selections = referenced_columns.get(scope, {SELECT_ALL}) + alias_count = source_column_alias_count.get(scope, 0) + + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. 
+ if scope.expression.args.get("distinct"): + parent_selections = {SELECT_ALL} + + if isinstance(scope.expression, exp.SetOperation): + set_op = scope.expression + if not (set_op.kind or set_op.side): + # Do not optimize this set operation if it's using the BigQuery specific + # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation + left, right = scope.union_scopes + if len(left.expression.selects) != len(right.expression.selects): + scope_sql = scope.expression.sql(dialect=dialect) + raise OptimizeError( + f"Invalid set operation due to column mismatch: {scope_sql}." + ) + + referenced_columns[left] = parent_selections + + if any(select.is_star for select in right.expression.selects): + referenced_columns[right] = parent_selections + elif not any(select.is_star for select in left.expression.selects): + if scope.expression.args.get("by_name"): + referenced_columns[right] = referenced_columns[left] + else: + referenced_columns[right] = { + right.expression.selects[i].alias_or_name + for i, select in enumerate(left.expression.selects) + if SELECT_ALL in parent_selections + or select.alias_or_name in parent_selections + } + + if isinstance(scope.expression, exp.Select): + if remove_unused_selections: + _remove_unused_selections(scope, parent_selections, schema, alias_count) + + if scope.expression.is_star: + continue + + # Group columns by source name + selects = defaultdict(set) + for col in scope.columns: + table_name = col.table + col_name = col.name + selects[table_name].add(col_name) + + # Push the selected columns down to the next scope + for name, (node, source) in scope.selected_sources.items(): + if isinstance(source, Scope): + select = seq_get(source.expression.selects, 0) + + if scope.pivots or isinstance(select, exp.QueryTransform): + columns = {SELECT_ALL} + else: + columns = selects.get(name) or set() + + referenced_columns[source].update(columns) + + column_aliases = node.alias_column_names + if column_aliases: + source_column_alias_count[source] = len(column_aliases) + + return expression + + +def _remove_unused_selections(scope, parent_selections, schema, alias_count): + order = scope.expression.args.get("order") + + if order: + # Assume columns without a qualified table are references to output columns + order_refs = {c.name for c in order.find_all(exp.Column) if not c.table} + else: + order_refs = set() + + new_selections = [] + removed = False + star = False + is_agg = False + + select_all = SELECT_ALL in parent_selections + + for selection in scope.expression.selects: + name = selection.alias_or_name + + if ( + select_all + or name in parent_selections + or name in order_refs + or alias_count > 0 + ): + new_selections.append(selection) + alias_count -= 1 + else: + if selection.is_star: + star = True + removed = True + + if not is_agg and selection.find(exp.AggFunc): + is_agg = True + + if star: + resolver = Resolver(scope, schema) + names = {s.alias_or_name for s in new_selections} + + for name in sorted(parent_selections): + if name not in names: + new_selections.append( + alias( + exp.column(name, table=resolver.get_table(name)), + name, + copy=False, + ) + ) + + # If there are no remaining selections, just select a single constant + if not new_selections: + new_selections.append(default_selection(is_agg)) + + scope.expression.select(*new_selections, append=False, copy=False) + + if removed: + scope.clear_cache() diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/qualify.py 
b/third_party/bigframes_vendored/sqlglot/optimizer/qualify.py new file mode 100644 index 0000000000..2e77b07160 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/qualify.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType +from bigframes_vendored.sqlglot.optimizer.isolate_table_selects import ( + isolate_table_selects, +) +from bigframes_vendored.sqlglot.optimizer.normalize_identifiers import ( + normalize_identifiers, +) +from bigframes_vendored.sqlglot.optimizer.qualify_columns import ( + qualify_columns as qualify_columns_func, +) +from bigframes_vendored.sqlglot.optimizer.qualify_columns import ( + quote_identifiers as quote_identifiers_func, +) +from bigframes_vendored.sqlglot.optimizer.qualify_columns import ( + validate_qualify_columns as validate_qualify_columns_func, +) +from bigframes_vendored.sqlglot.optimizer.qualify_tables import qualify_tables +from bigframes_vendored.sqlglot.schema import ensure_schema, Schema + + +def qualify( + expression: exp.Expression, + dialect: DialectType = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[dict | Schema] = None, + expand_alias_refs: bool = True, + expand_stars: bool = True, + infer_schema: t.Optional[bool] = None, + isolate_tables: bool = False, + qualify_columns: bool = True, + allow_partial_qualification: bool = False, + validate_qualify_columns: bool = True, + quote_identifiers: bool = True, + identify: bool = True, + canonicalize_table_aliases: bool = False, + on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None, + sql: t.Optional[str] = None, +) -> exp.Expression: + """ + Rewrite sqlglot AST to have normalized and qualified tables and columns. + + This step is necessary for all further SQLGlot optimizations. + + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify(expression, schema=schema).sql() + 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"' + + Args: + expression: Expression to qualify. + db: Default database name for tables. + catalog: Default catalog name for tables. + schema: Schema to infer column names and types. + expand_alias_refs: Whether to expand references to aliases. + expand_stars: Whether to expand star queries. This is a necessary step + for most of the optimizer's rules to work; do not set to False unless you + know what you're doing! + infer_schema: Whether to infer the schema if missing. + isolate_tables: Whether to isolate table selects. + qualify_columns: Whether to qualify columns. + allow_partial_qualification: Whether to allow partial qualification. + validate_qualify_columns: Whether to validate columns. + quote_identifiers: Whether to run the quote_identifiers step. + This step is necessary to ensure correctness for case sensitive queries. + But this flag is provided in case this step is performed at a later time. + identify: If True, quote all identifiers, else only necessary ones. + canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources + instead of preserving table names. + on_qualify: Callback after a table has been qualified. + sql: Original SQL string for error highlighting. If not provided, errors will not include + highlighting. Requires that the expression has position metadata from parsing. + + Returns: + The qualified expression. 
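A sketch of star expansion through `qualify`, one of the steps its docstring warns not to disable (upstream package, invented schema):

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"t": {"a": "INT", "b": "INT"}}
print(qualify(sqlglot.parse_one("SELECT * FROM t"), schema=schema).sql())
# Roughly: SELECT "t"."a" AS "a", "t"."b" AS "b" FROM "t" AS "t"
# The star expands to the schema's columns, each qualified and aliased.
```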
+ """ + schema = ensure_schema(schema, dialect=dialect) + dialect = Dialect.get_or_raise(dialect) + + expression = normalize_identifiers( + expression, + dialect=dialect, + store_original_column_identifiers=True, + ) + expression = qualify_tables( + expression, + db=db, + catalog=catalog, + dialect=dialect, + on_qualify=on_qualify, + canonicalize_table_aliases=canonicalize_table_aliases, + ) + + if isolate_tables: + expression = isolate_table_selects(expression, schema=schema) + + if qualify_columns: + expression = qualify_columns_func( + expression, + schema, + expand_alias_refs=expand_alias_refs, + expand_stars=expand_stars, + infer_schema=infer_schema, + allow_partial_qualification=allow_partial_qualification, + ) + + if quote_identifiers: + expression = quote_identifiers_func( + expression, dialect=dialect, identify=identify + ) + + if validate_qualify_columns: + validate_qualify_columns_func(expression, sql=sql) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/qualify_columns.py b/third_party/bigframes_vendored/sqlglot/optimizer/qualify_columns.py new file mode 100644 index 0000000000..aeaf70f78b --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/qualify_columns.py @@ -0,0 +1,1051 @@ +from __future__ import annotations + +import itertools +import typing as t + +from bigframes_vendored.sqlglot import alias, exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType +from bigframes_vendored.sqlglot.errors import highlight_sql, OptimizeError +from bigframes_vendored.sqlglot.helper import seq_get +from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator +from bigframes_vendored.sqlglot.optimizer.resolver import Resolver +from bigframes_vendored.sqlglot.optimizer.scope import ( + build_scope, + Scope, + traverse_scope, + walk_in_scope, +) +from bigframes_vendored.sqlglot.optimizer.simplify import simplify_parens +from bigframes_vendored.sqlglot.schema import ensure_schema, Schema + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + + +def qualify_columns( + expression: exp.Expression, + schema: t.Dict | Schema, + expand_alias_refs: bool = True, + expand_stars: bool = True, + infer_schema: t.Optional[bool] = None, + allow_partial_qualification: bool = False, + dialect: DialectType = None, +) -> exp.Expression: + """ + Rewrite sqlglot AST to have fully qualified columns. + + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify_columns(expression, schema).sql() + 'SELECT tbl.col AS col FROM tbl' + + Args: + expression: Expression to qualify. + schema: Database schema. + expand_alias_refs: Whether to expand references to aliases. + expand_stars: Whether to expand star queries. This is a necessary step + for most of the optimizer's rules to work; do not set to False unless you + know what you're doing! + infer_schema: Whether to infer the schema if missing. + allow_partial_qualification: Whether to allow partial qualification. + + Returns: + The qualified expression. 
+ + Notes: + - Currently only handles a single PIVOT or UNPIVOT operator + """ + schema = ensure_schema(schema, dialect=dialect) + annotator = TypeAnnotator(schema) + infer_schema = schema.empty if infer_schema is None else infer_schema + dialect = schema.dialect or Dialect() + pseudocolumns = dialect.PSEUDOCOLUMNS + + for scope in traverse_scope(expression): + if dialect.PREFER_CTE_ALIAS_COLUMN: + pushdown_cte_alias_columns(scope) + + scope_expression = scope.expression + is_select = isinstance(scope_expression, exp.Select) + + _separate_pseudocolumns(scope, pseudocolumns) + + resolver = Resolver(scope, schema, infer_schema=infer_schema) + _pop_table_column_aliases(scope.ctes) + _pop_table_column_aliases(scope.derived_tables) + using_column_tables = _expand_using(scope, resolver) + + if ( + schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION + ) and expand_alias_refs: + _expand_alias_refs( + scope, + resolver, + dialect, + expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF, + ) + + _convert_columns_to_dots(scope, resolver) + _qualify_columns( + scope, + resolver, + allow_partial_qualification=allow_partial_qualification, + ) + + if not schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver, dialect) + + if is_select: + if expand_stars: + _expand_stars( + scope, + resolver, + using_column_tables, + pseudocolumns, + annotator, + ) + qualify_outputs(scope) + + _expand_group_by(scope, dialect) + + # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) + # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT + _expand_order_by_and_distinct_on(scope, resolver) + + if dialect.ANNOTATE_ALL_SCOPES: + annotator.annotate_scope(scope) + + return expression + + +def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E: + """Raise an `OptimizeError` if any columns aren't qualified""" + all_unqualified_columns = [] + for scope in traverse_scope(expression): + if isinstance(scope.expression, exp.Select): + unqualified_columns = scope.unqualified_columns + + if ( + scope.external_columns + and not scope.is_correlated_subquery + and not scope.pivots + ): + column = scope.external_columns[0] + for_table = f" for table: '{column.table}'" if column.table else "" + line = column.this.meta.get("line") + col = column.this.meta.get("col") + start = column.this.meta.get("start") + end = column.this.meta.get("end") + + error_msg = f"Column '{column.name}' could not be resolved{for_table}." + if line and col: + error_msg += f" Line: {line}, Col: {col}" + if sql and start is not None and end is not None: + formatted_sql = highlight_sql(sql, [(start, end)])[0] + error_msg += f"\n {formatted_sql}" + + raise OptimizeError(error_msg) + + if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: + # New columns produced by the UNPIVOT can't be qualified, but there may be columns + # under the UNPIVOT's IN clause that can and should be qualified. We recompute + # this list here to ensure those in the former category will be excluded. 
+ unpivot_columns = set(_unpivot_columns(scope.pivots[0])) + unqualified_columns = [ + c for c in unqualified_columns if c not in unpivot_columns + ] + + all_unqualified_columns.extend(unqualified_columns) + + if all_unqualified_columns: + first_column = all_unqualified_columns[0] + line = first_column.this.meta.get("line") + col = first_column.this.meta.get("col") + start = first_column.this.meta.get("start") + end = first_column.this.meta.get("end") + + error_msg = f"Ambiguous column '{first_column.name}'" + if line and col: + error_msg += f" (Line: {line}, Col: {col})" + if sql and start is not None and end is not None: + formatted_sql = highlight_sql(sql, [(start, end)])[0] + error_msg += f"\n {formatted_sql}" + + raise OptimizeError(error_msg) + + return expression + + +def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None: + if not pseudocolumns: + return + + has_pseudocolumns = False + scope_expression = scope.expression + + for column in scope.columns: + name = column.name.upper() + if name not in pseudocolumns: + continue + + if name != "LEVEL" or ( + isinstance(scope_expression, exp.Select) + and scope_expression.args.get("connect") + ): + column.replace(exp.Pseudocolumn(**column.args)) + has_pseudocolumns = True + + if has_pseudocolumns: + scope.clear_cache() + + +def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: + name_columns = [ + field.this + for field in unpivot.fields + if isinstance(field, exp.In) and isinstance(field.this, exp.Column) + ] + value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) + + return itertools.chain(name_columns, value_columns) + + +def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: + """ + Remove table column aliases. + + For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) + """ + for derived_table in derived_tables: + if ( + isinstance(derived_table.parent, exp.With) + and derived_table.parent.recursive + ): + continue + table_alias = derived_table.args.get("alias") + if table_alias: + table_alias.set("columns", None) + + +def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: + columns = {} + + def _update_source_columns(source_name: str) -> None: + for column_name in resolver.get_source_columns(source_name): + if column_name not in columns: + columns[column_name] = source_name + + joins = list(scope.find_all(exp.Join)) + names = {join.alias_or_name for join in joins} + ordered = [key for key in scope.selected_sources if key not in names] + + if names and not ordered: + raise OptimizeError(f"Joins {names} missing source table {scope.expression}") + + # Mapping of automatically joined column names to an ordered set of source names (dict). 
+ column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} + + for source_name in ordered: + _update_source_columns(source_name) + + for i, join in enumerate(joins): + source_table = ordered[-1] + if source_table: + _update_source_columns(source_table) + + join_table = join.alias_or_name + ordered.append(join_table) + + using = join.args.get("using") + if not using: + continue + + join_columns = resolver.get_source_columns(join_table) + conditions = [] + using_identifier_count = len(using) + is_semi_or_anti_join = join.is_semi_or_anti_join + + for identifier in using: + identifier = identifier.name + table = columns.get(identifier) + + if not table or identifier not in join_columns: + if (columns and "*" not in columns) and join_columns: + raise OptimizeError(f"Cannot automatically join: {identifier}") + + table = table or source_table + + if i == 0 or using_identifier_count == 1: + lhs: exp.Expression = exp.column(identifier, table=table) + else: + coalesce_columns = [ + exp.column(identifier, table=t) + for t in ordered[:-1] + if identifier in resolver.get_source_columns(t) + ] + if len(coalesce_columns) > 1: + lhs = exp.func("coalesce", *coalesce_columns) + else: + lhs = exp.column(identifier, table=table) + + conditions.append(lhs.eq(exp.column(identifier, table=join_table))) + + # Set all values in the dict to None, because we only care about the key ordering + tables = column_tables.setdefault(identifier, {}) + + # Do not update the dict if this was a SEMI/ANTI join in + # order to avoid generating COALESCE columns for this join pair + if not is_semi_or_anti_join: + if table not in tables: + tables[table] = None + if join_table not in tables: + tables[join_table] = None + + join.set("using", None) + join.set("on", exp.and_(*conditions, copy=False)) + + if column_tables: + for column in scope.columns: + if not column.table and column.name in column_tables: + tables = column_tables[column.name] + coalesce_args = [ + exp.column(column.name, table=table) for table in tables + ] + replacement: exp.Expression = exp.func("coalesce", *coalesce_args) + + if isinstance(column.parent, exp.Select): + # Ensure the USING column keeps its name if it's projected + replacement = alias(replacement, alias=column.name, copy=False) + elif isinstance(column.parent, exp.Struct): + # Ensure the USING column keeps its name if it's an anonymous STRUCT field + replacement = exp.PropertyEQ( + this=exp.to_identifier(column.name), expression=replacement + ) + + scope.replace(column, replacement) + + return column_tables + + +def _expand_alias_refs( + scope: Scope, + resolver: Resolver, + dialect: Dialect, + expand_only_groupby: bool = False, +) -> None: + """ + Expand references to aliases. 
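`_expand_using` above rewrites USING into an explicit ON condition and keeps projected join keys stable via COALESCE; a sketch of the observable effect through the public `qualify` entry point (upstream package, invented schema):

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"x": {"id": "INT", "a": "INT"}, "y": {"id": "INT", "b": "INT"}}
sql = "SELECT id FROM x JOIN y USING (id)"
print(qualify(sqlglot.parse_one(sql), schema=schema).sql())
# Roughly: SELECT COALESCE("x"."id", "y"."id") AS "id"
#          FROM "x" AS "x" JOIN "y" AS "y" ON "x"."id" = "y"."id"
```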
+ Example: + SELECT y.foo AS bar, bar * 2 AS baz FROM y + => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y + """ + expression = scope.expression + + if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION: + return + + alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} + projections = {s.alias_or_name for s in expression.selects} + replaced = False + + def replace_columns( + node: t.Optional[exp.Expression], + resolve_table: bool = False, + literal_index: bool = False, + ) -> None: + nonlocal replaced + is_group_by = isinstance(node, exp.Group) + is_having = isinstance(node, exp.Having) + if not node or (expand_only_groupby and not is_group_by): + return + + for column in walk_in_scope(node, prune=lambda node: node.is_star): + if not isinstance(column, exp.Column): + continue + + # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g: + # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded + # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col)) + # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns + if expand_only_groupby and is_group_by and column.parent is not node: + continue + + skip_replace = False + table = ( + resolver.get_table(column.name) + if resolve_table and not column.table + else None + ) + alias_expr, i = alias_to_expression.get(column.name, (None, 1)) + + if alias_expr: + skip_replace = bool( + alias_expr.find(exp.AggFunc) + and column.find_ancestor(exp.AggFunc) + and not isinstance( + column.find_ancestor(exp.Window, exp.Select), exp.Window + ) + ) + + # BigQuery's having clause gets confused if an alias matches a source. + # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1; + # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b) + # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed" + if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES: + skip_replace = skip_replace or any( + node.parts[0].name in projections + for node in alias_expr.find_all(exp.Column) + ) + elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and ( + is_group_by or is_having + ): + column_table = table.name if table else column.table + if column_table in projections: + # BigQuery's GROUP BY and HAVING clauses get confused if the column name + # matches a source name and a projection. 
For instance: + # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1 + # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table + # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause" + column.replace(exp.to_identifier(column.name)) + replaced = True + return + + if table and (not alias_expr or skip_replace): + column.set("table", table) + elif not column.table and alias_expr and not skip_replace: + if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and ( + literal_index or resolve_table + ): + if literal_index: + column.replace(exp.Literal.number(i)) + replaced = True + else: + replaced = True + column = column.replace(exp.paren(alias_expr)) + simplified = simplify_parens(column, dialect) + if simplified is not column: + column.replace(simplified) + + for i, projection in enumerate(expression.selects): + replace_columns(projection) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = (projection.this, i + 1) + + parent_scope = scope + on_right_sub_tree = False + while parent_scope and not parent_scope.is_cte: + if parent_scope.is_union: + on_right_sub_tree = ( + parent_scope.parent.expression.right is parent_scope.expression + ) + parent_scope = parent_scope.parent + + # We shouldn't expand aliases if they match the recursive CTE's columns + # and we are in the recursive part (right sub tree) of the CTE + if parent_scope and on_right_sub_tree: + cte = parent_scope.expression.parent + if cte.find_ancestor(exp.With).recursive: + for recursive_cte_column in cte.args["alias"].columns or cte.this.selects: + alias_to_expression.pop(recursive_cte_column.output_name, None) + + replace_columns(expression.args.get("where")) + replace_columns(expression.args.get("group"), literal_index=True) + replace_columns(expression.args.get("having"), resolve_table=True) + replace_columns(expression.args.get("qualify"), resolve_table=True) + + if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS: + for join in expression.args.get("joins") or []: + replace_columns(join) + + if replaced: + scope.clear_cache() + + +def _expand_group_by(scope: Scope, dialect: Dialect) -> None: + expression = scope.expression + group = expression.args.get("group") + if not group: + return + + group.set( + "expressions", _expand_positional_references(scope, group.expressions, dialect) + ) + expression.set("group", group) + + +def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None: + for modifier_key in ("order", "distinct"): + modifier = scope.expression.args.get(modifier_key) + if isinstance(modifier, exp.Distinct): + modifier = modifier.args.get("on") + + if not isinstance(modifier, exp.Expression): + continue + + modifier_expressions = modifier.expressions + if modifier_key == "order": + modifier_expressions = [ordered.this for ordered in modifier_expressions] + + for original, expanded in zip( + modifier_expressions, + _expand_positional_references( + scope, modifier_expressions, resolver.dialect, alias=True + ), + ): + for agg in original.find_all(exp.AggFunc): + for col in agg.find_all(exp.Column): + if not col.table: + col.set("table", resolver.get_table(col.name)) + + original.replace(expanded) + + if scope.expression.args.get("group"): + selects = { + s.this: exp.column(s.alias_or_name) for s in scope.expression.selects + } + + for expression in modifier_expressions: + expression.replace( + 
exp.to_identifier(_select_by_pos(scope, expression).alias) + if expression.is_int + else selects.get(expression, expression) + ) + + +def _expand_positional_references( + scope: Scope, + expressions: t.Iterable[exp.Expression], + dialect: Dialect, + alias: bool = False, +) -> t.List[exp.Expression]: + new_nodes: t.List[exp.Expression] = [] + ambiguous_projections = None + + for node in expressions: + if node.is_int: + select = _select_by_pos(scope, t.cast(exp.Literal, node)) + + if alias: + new_nodes.append(exp.column(select.args["alias"].copy())) + else: + select = select.this + + if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES: + if ambiguous_projections is None: + # When a projection name is also a source name and it is referenced in the + # GROUP BY clause, BQ can't understand what the identifier corresponds to + ambiguous_projections = { + s.alias_or_name + for s in scope.expression.selects + if s.alias_or_name in scope.selected_sources + } + + ambiguous = any( + column.parts[0].name in ambiguous_projections + for column in select.find_all(exp.Column) + ) + else: + ambiguous = False + + if ( + isinstance(select, exp.CONSTANTS) + or select.is_number + or select.find(exp.Explode, exp.Unnest) + or ambiguous + ): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) + else: + new_nodes.append(node) + + return new_nodes + + +def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: + try: + return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + + +def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None: + """ + Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`. + + These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be + normalized to `Dot(Dot(...(., field1), field2, ...))` to be qualified properly. 
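+ + For example (illustrative): for a table t with a STRUCT column s, a lookup parsed as + Column(this=f2, table=f1, db=s) for s.f1.f2 is rewritten here to Dot(Dot(t.s, f1), f2), + so that qualification can attach the source table to the root field s.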
+ """ + converted = False + for column in itertools.chain(scope.columns, scope.stars): + if isinstance(column, exp.Dot): + continue + + column_table: t.Optional[str | exp.Identifier] = column.table + dot_parts = column.meta.pop("dot_parts", []) + if ( + column_table + and column_table not in scope.sources + and ( + not scope.parent + or column_table not in scope.parent.sources + or not scope.is_correlated_subquery + ) + ): + root, *parts = column.parts + + if root.name in scope.sources: + # The struct is already qualified, but we still need to change the AST + column_table = root + root, *parts = parts + was_qualified = True + else: + column_table = resolver.get_table(root.name) + was_qualified = False + + if column_table: + converted = True + new_column = exp.column(root, table=column_table) + + if dot_parts: + # Remove the actual column parts from the rest of dot parts + new_column.meta["dot_parts"] = dot_parts[ + 2 if was_qualified else 1 : + ] + + column.replace(exp.Dot.build([new_column, *parts])) + + if converted: + # We want to re-aggregate the converted columns, otherwise they'd be skipped in + # a `for column in scope.columns` iteration, even though they shouldn't be + scope.clear_cache() + + +def _qualify_columns( + scope: Scope, + resolver: Resolver, + allow_partial_qualification: bool, +) -> None: + """Disambiguate columns, ensuring each column specifies a source""" + for column in scope.columns: + column_table = column.table + column_name = column.name + + if column_table and column_table in scope.sources: + source_columns = resolver.get_source_columns(column_table) + if ( + not allow_partial_qualification + and source_columns + and column_name not in source_columns + and "*" not in source_columns + ): + raise OptimizeError(f"Unknown column: {column_name}") + + if not column_table: + if scope.pivots and not column.find_ancestor(exp.Pivot): + # If the column is under the Pivot expression, we need to qualify it + # using the name of the pivoted source instead of the pivot's alias + column.set("table", exp.to_identifier(scope.pivots[0].alias)) + continue + + # column_table can be a '' because bigquery unnest has no table alias + column_table = resolver.get_table(column) + + if column_table: + column.set("table", column_table) + elif ( + resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS + and len(column.parts) == 1 + and column_name in scope.selected_sources + ): + # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records + scope.replace(column, exp.TableColumn(this=column.this)) + + for pivot in scope.pivots: + for column in pivot.find_all(exp.Column): + if not column.table and column.name in resolver.all_columns: + column_table = resolver.get_table(column.name) + if column_table: + column.set("table", column_table) + + +def _expand_struct_stars_no_parens( + expression: exp.Dot, +) -> t.List[exp.Alias]: + """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column""" + + dot_column = expression.find(exp.Column) + if not isinstance(dot_column, exp.Column) or not dot_column.is_type( + exp.DataType.Type.STRUCT + ): + return [] + + # All nested struct values are ColumnDefs, so normalize the first exp.Column in one + dot_column = dot_column.copy() + starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type) + + # First part is the table name and last part is the star so they can be dropped + dot_parts = expression.parts[1:-1] + + # If we're expanding a nested struct eg. 
t.c.f1.f2.* find the last struct (f2 in this case) + for part in dot_parts[1:]: + for field in t.cast(exp.DataType, starting_struct.kind).expressions: + # Unable to expand star unless all fields are named + if not isinstance(field.this, exp.Identifier): + return [] + + if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT): + starting_struct = field + break + else: + # There is no matching field in the struct + return [] + + taken_names = set() + new_selections = [] + + for field in t.cast(exp.DataType, starting_struct.kind).expressions: + name = field.name + + # Ambiguous or anonymous fields can't be expanded + if name in taken_names or not isinstance(field.this, exp.Identifier): + return [] + + taken_names.add(name) + + this = field.this.copy() + root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])] + new_column = exp.column( + t.cast(exp.Identifier, root), + table=dot_column.args.get("table"), + fields=t.cast(t.List[exp.Identifier], parts), + ) + new_selections.append(alias(new_column, this, copy=False)) + + return new_selections + + +def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]: + """[RisingWave] Expand/Flatten (.bar).*, where bar is a struct column""" + + # it is not ().* pattern, which means we can't expand + if not isinstance(expression.this, exp.Paren): + return [] + + # find column definition to get data-type + dot_column = expression.find(exp.Column) + if not isinstance(dot_column, exp.Column) or not dot_column.is_type( + exp.DataType.Type.STRUCT + ): + return [] + + parent = dot_column.parent + starting_struct = dot_column.type + + # walk up AST and down into struct definition in sync + while parent is not None: + if isinstance(parent, exp.Paren): + parent = parent.parent + continue + + # if parent is not a dot, then something is wrong + if not isinstance(parent, exp.Dot): + return [] + + # if the rhs of the dot is star we are done + rhs = parent.right + if isinstance(rhs, exp.Star): + break + + # if it is not identifier, then something is wrong + if not isinstance(rhs, exp.Identifier): + return [] + + # Check if current rhs identifier is in struct + matched = False + for struct_field_def in t.cast(exp.DataType, starting_struct).expressions: + if struct_field_def.name == rhs.name: + matched = True + starting_struct = struct_field_def.kind # update struct + break + + if not matched: + return [] + + parent = parent.parent + + # build new aliases to expand star + new_selections = [] + + # fetch the outermost parentheses for new aliaes + outer_paren = expression.this + + for struct_field_def in t.cast(exp.DataType, starting_struct).expressions: + new_identifier = struct_field_def.this.copy() + new_dot = exp.Dot.build([outer_paren.copy(), new_identifier]) + new_alias = alias(new_dot, new_identifier, copy=False) + new_selections.append(new_alias) + + return new_selections + + +def _expand_stars( + scope: Scope, + resolver: Resolver, + using_column_tables: t.Dict[str, t.Any], + pseudocolumns: t.Set[str], + annotator: TypeAnnotator, +) -> None: + """Expand stars to lists of column selections""" + + new_selections: t.List[exp.Expression] = [] + except_columns: t.Dict[int, t.Set[str]] = {} + replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {} + rename_columns: t.Dict[int, t.Dict[str, str]] = {} + + coalesced_columns = set() + dialect = resolver.dialect + + pivot_output_columns = None + pivot_exclude_columns: t.Set[str] = set() + + pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) + if 
isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: + if pivot.unpivot: + pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] + + for field in pivot.fields: + if isinstance(field, exp.In): + pivot_exclude_columns.update( + c.output_name + for e in field.expressions + for c in e.find_all(exp.Column) + ) + + else: + pivot_exclude_columns = set( + c.output_name for c in pivot.find_all(exp.Column) + ) + + pivot_output_columns = [ + c.output_name for c in pivot.args.get("columns", []) + ] + if not pivot_output_columns: + pivot_output_columns = [c.alias_or_name for c in pivot.expressions] + + if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any( + isinstance(col, exp.Dot) for col in scope.stars + ): + # Found struct expansion, annotate scope ahead of time + annotator.annotate_scope(scope) + + for expression in scope.expression.selects: + tables = [] + if isinstance(expression, exp.Star): + tables.extend(scope.selected_sources) + _add_except_columns(expression, tables, except_columns) + _add_replace_columns(expression, tables, replace_columns) + _add_rename_columns(expression, tables, rename_columns) + elif expression.is_star: + if not isinstance(expression, exp.Dot): + tables.append(expression.table) + _add_except_columns(expression.this, tables, except_columns) + _add_replace_columns(expression.this, tables, replace_columns) + _add_rename_columns(expression.this, tables, rename_columns) + elif ( + dialect.SUPPORTS_STRUCT_STAR_EXPANSION + and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS + ): + struct_fields = _expand_struct_stars_no_parens(expression) + if struct_fields: + new_selections.extend(struct_fields) + continue + elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS: + struct_fields = _expand_struct_stars_with_parens(expression) + if struct_fields: + new_selections.extend(struct_fields) + continue + + if not tables: + new_selections.append(expression) + continue + + for table in tables: + if table not in scope.sources: + raise OptimizeError(f"Unknown table: {table}") + + columns = resolver.get_source_columns(table, only_visible=True) + columns = columns or scope.outer_columns + + if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR: + columns = [ + name for name in columns if name.upper() not in pseudocolumns + ] + + if not columns or "*" in columns: + return + + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() + renamed_columns = rename_columns.get(table_id, {}) + replaced_columns = replace_columns.get(table_id, {}) + + if pivot: + if pivot_output_columns and pivot_exclude_columns: + pivot_columns = [ + c for c in columns if c not in pivot_exclude_columns + ] + pivot_columns.extend(pivot_output_columns) + else: + pivot_columns = pivot.alias_column_names + + if pivot_columns: + new_selections.extend( + alias(exp.column(name, table=pivot.alias), name, copy=False) + for name in pivot_columns + if name not in columns_to_exclude + ) + continue + + for name in columns: + if name in columns_to_exclude or name in coalesced_columns: + continue + if name in using_column_tables and table in using_column_tables[name]: + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce_args = [exp.column(name, table=table) for table in tables] + + new_selections.append( + alias( + exp.func("coalesce", *coalesce_args), alias=name, copy=False + ) + ) + else: + alias_ = renamed_columns.get(name, name) + selection_expr = replaced_columns.get(name) or exp.column( + name, table=table + ) + new_selections.append( + 
alias(selection_expr, alias_, copy=False) + if alias_ != name + else selection_expr + ) + + # Ensures we don't overwrite the initial selections with an empty list + if new_selections and isinstance(scope.expression, exp.Select): + scope.expression.set("expressions", new_selections) + + + def _add_except_columns( + expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] + ) -> None: + except_ = expression.args.get("except_") + + if not except_: + return + + columns = {e.name for e in except_} + + for table in tables: + except_columns[id(table)] = columns + + + def _add_rename_columns( + expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]] + ) -> None: + rename = expression.args.get("rename") + + if not rename: + return + + columns = {e.this.name: e.alias for e in rename} + + for table in tables: + rename_columns[id(table)] = columns + + + def _add_replace_columns( + expression: exp.Expression, + tables, + replace_columns: t.Dict[int, t.Dict[str, exp.Alias]], + ) -> None: + replace = expression.args.get("replace") + + if not replace: + return + + columns = {e.alias: e for e in replace} + + for table in tables: + replace_columns[id(table)] = columns + + + def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: + """Ensure all output columns are aliased""" + if isinstance(scope_or_expression, exp.Expression): + scope = build_scope(scope_or_expression) + if not isinstance(scope, Scope): + return + else: + scope = scope_or_expression + + new_selections = [] + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.expression.selects, scope.outer_columns) + ): + if selection is None or isinstance(selection, exp.QueryTransform): + break + + if isinstance(selection, exp.Subquery): + if not selection.output_name: + selection.set( + "alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")) + ) + elif ( + not isinstance(selection, (exp.Alias, exp.Aliases)) + and not selection.is_star + ): + selection = alias( + selection, + alias=selection.output_name or f"_col_{i}", + copy=False, + ) + if aliased_column: + selection.set("alias", exp.to_identifier(aliased_column)) + + new_selections.append(selection) + + if new_selections and isinstance(scope.expression, exp.Select): + scope.expression.set("expressions", new_selections) + + + def quote_identifiers( + expression: E, dialect: DialectType = None, identify: bool = True + ) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + return expression.transform( + Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False + ) # type: ignore + + + def pushdown_cte_alias_columns(scope: Scope) -> None: + """ + Pushes down the CTE alias columns into the projections. + + This step is useful in Snowflake, where the CTE alias columns can be referenced in the HAVING clause. + + Args: + scope: Scope in which to find CTEs whose alias columns should be pushed down.
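+ + Example (illustrative): WITH cte(x) AS (SELECT a FROM t) SELECT x FROM cte is rewritten so the + projection carries the alias, i.e. WITH cte(x) AS (SELECT a AS x FROM t) SELECT x FROM cte.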
+ """ + for cte in scope.ctes: + if cte.alias_column_names and isinstance(cte.this, exp.Select): + new_expressions = [] + for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): + if isinstance(projection, exp.Alias): + projection.set("alias", exp.to_identifier(_alias)) + else: + projection = alias(projection, alias=_alias) + new_expressions.append(projection) + cte.this.set("expressions", new_expressions) diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/qualify_tables.py b/third_party/bigframes_vendored/sqlglot/optimizer/qualify_tables.py new file mode 100644 index 0000000000..428ed25b1d --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/qualify_tables.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType +from bigframes_vendored.sqlglot.helper import ensure_list, name_sequence, seq_get +from bigframes_vendored.sqlglot.optimizer.normalize_identifiers import ( + normalize_identifiers, +) +from bigframes_vendored.sqlglot.optimizer.scope import Scope, traverse_scope + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + + +def qualify_tables( + expression: E, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, + on_qualify: t.Optional[t.Callable[[exp.Table], None]] = None, + dialect: DialectType = None, + canonicalize_table_aliases: bool = False, +) -> E: + """ + Rewrite sqlglot AST to have fully qualified tables. Join constructs such as + (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t. + + Examples: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") + >>> qualify_tables(expression, db="db").sql() + 'SELECT 1 FROM db.tbl AS tbl' + >>> + >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t") + >>> qualify_tables(expression).sql() + 'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t' + + Args: + expression: Expression to qualify + db: Database name + catalog: Catalog name + on_qualify: Callback after a table has been qualified. + dialect: The dialect to parse catalog and schema into. + canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources + instead of preserving table names. Defaults to False. + + Returns: + The qualified expression. 
+ """ + dialect = Dialect.get_or_raise(dialect) + next_alias_name = name_sequence("_") + + if db := db or None: + db = exp.parse_identifier(db, dialect=dialect) + db.meta["is_table"] = True + db = normalize_identifiers(db, dialect=dialect) + if catalog := catalog or None: + catalog = exp.parse_identifier(catalog, dialect=dialect) + catalog.meta["is_table"] = True + catalog = normalize_identifiers(catalog, dialect=dialect) + + def _qualify(table: exp.Table) -> None: + if isinstance(table.this, exp.Identifier): + if db and not table.args.get("db"): + table.set("db", db.copy()) + if catalog and not table.args.get("catalog") and table.args.get("db"): + table.set("catalog", catalog.copy()) + + if (db or catalog) and not isinstance(expression, exp.Query): + with_ = expression.args.get("with_") or exp.With() + cte_names = {cte.alias_or_name for cte in with_.expressions} + + for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): + if isinstance(node, exp.Table) and node.name not in cte_names: + _qualify(node) + + def _set_alias( + expression: exp.Expression, + canonical_aliases: t.Dict[str, str], + target_alias: t.Optional[str] = None, + scope: t.Optional[Scope] = None, + normalize: bool = False, + columns: t.Optional[t.List[t.Union[str, exp.Identifier]]] = None, + ) -> None: + alias = expression.args.get("alias") or exp.TableAlias() + + if canonicalize_table_aliases: + new_alias_name = next_alias_name() + canonical_aliases[alias.name or target_alias or ""] = new_alias_name + elif not alias.name: + new_alias_name = target_alias or next_alias_name() + if normalize and target_alias: + new_alias_name = normalize_identifiers( + new_alias_name, dialect=dialect + ).name + else: + return + + alias.set("this", exp.to_identifier(new_alias_name)) + + if columns: + alias.set("columns", [exp.to_identifier(c) for c in columns]) + + expression.set("alias", alias) + + if scope: + scope.rename_source(None, new_alias_name) + + for scope in traverse_scope(expression): + local_columns = scope.local_columns + canonical_aliases: t.Dict[str, str] = {} + + for query in scope.subqueries: + subquery = query.parent + if isinstance(subquery, exp.Subquery): + subquery.unwrap().replace(subquery) + + for derived_table in scope.derived_tables: + unnested = derived_table.unnest() + if isinstance(unnested, exp.Table): + joins = unnested.args.get("joins") + unnested.set("joins", None) + derived_table.this.replace( + exp.select("*").from_(unnested.copy(), copy=False) + ) + derived_table.this.set("joins", joins) + + _set_alias(derived_table, canonical_aliases, scope=scope) + if pivot := seq_get(derived_table.args.get("pivots") or [], 0): + _set_alias(pivot, canonical_aliases) + + table_aliases = {} + + for name, source in scope.sources.items(): + if isinstance(source, exp.Table): + # When the name is empty, it means that we have a non-table source, e.g. 
a pivoted cte + is_real_table_source = bool(name) + + if pivot := seq_get(source.args.get("pivots") or [], 0): + name = source.name + + table_this = source.this + table_alias = source.args.get("alias") + function_columns: t.List[t.Union[str, exp.Identifier]] = [] + if isinstance(table_this, exp.Func): + if not table_alias: + function_columns = ensure_list( + dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this)) + ) + elif columns := table_alias.columns: + function_columns = columns + elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES: + function_columns = ensure_list(source.alias_or_name) + source.set("alias", None) + name = None + + _set_alias( + source, + canonical_aliases, + target_alias=name or source.name or None, + normalize=True, + columns=function_columns, + ) + + source_fqn = ".".join(p.name for p in source.parts) + table_aliases[source_fqn] = source.args["alias"].this.copy() + + if pivot: + target_alias = source.alias if pivot.unpivot else None + _set_alias( + pivot, + canonical_aliases, + target_alias=target_alias, + normalize=True, + ) + + # This case corresponds to a pivoted CTE, we don't want to qualify that + if isinstance(scope.sources.get(source.alias_or_name), Scope): + continue + + if is_real_table_source: + _qualify(source) + + if on_qualify: + on_qualify(source) + elif isinstance(source, Scope) and source.is_udtf: + _set_alias(udtf := source.expression, canonical_aliases) + + table_alias = udtf.args["alias"] + + if isinstance(udtf, exp.Values) and not table_alias.columns: + column_aliases = [ + normalize_identifiers(i, dialect=dialect) + for i in dialect.generate_values_aliases(udtf) + ] + table_alias.set("columns", column_aliases) + + for table in scope.tables: + if not table.alias and isinstance(table.parent, (exp.From, exp.Join)): + _set_alias(table, canonical_aliases, target_alias=table.name) + + for column in local_columns: + table = column.table + + if column.db: + table_alias = table_aliases.get( + ".".join(p.name for p in column.parts[0:-1]) + ) + + if table_alias: + for p in exp.COLUMN_PARTS[1:]: + column.set(p, None) + + column.set("table", table_alias.copy()) + elif ( + canonical_aliases + and table + and (canonical_table := canonical_aliases.get(table, "")) + != column.table + ): + # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0 + column.set("table", exp.to_identifier(canonical_table)) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/resolver.py b/third_party/bigframes_vendored/sqlglot/optimizer/resolver.py new file mode 100644 index 0000000000..82d12a702c --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/resolver.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import itertools +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect +from bigframes_vendored.sqlglot.errors import OptimizeError +from bigframes_vendored.sqlglot.helper import seq_get, SingleValuedMapping +from bigframes_vendored.sqlglot.optimizer.scope import Scope + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.schema import Schema + + +class Resolver: + """ + Helper for resolving columns. + + This is a class so we can lazily load some things and easily share them across functions. 
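+ + Example (illustrative): for SELECT a FROM x JOIN y, with a schema in which only x has a + column named a, Resolver(scope, schema).get_table("a") returns the identifier x.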
+ """ + + def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): + self.scope = scope + self.schema = schema + self.dialect = schema.dialect or Dialect() + self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None + self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None + self._all_columns: t.Optional[t.Set[str]] = None + self._infer_schema = infer_schema + self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} + + def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]: + """ + Get the table for a column name. + + Args: + column: The column expression (or column name) to find the table for. + Returns: + The table name if it can be found/inferred. + """ + column_name = column if isinstance(column, str) else column.name + + table_name = self._get_table_name_from_sources(column_name) + + if not table_name and isinstance(column, exp.Column): + # Fall-back case: If we couldn't find the `table_name` from ALL of the sources, + # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition, + # we may be able to disambiguate based on the source order. + if join_context := self._get_column_join_context(column): + # In this case, the return value will be the join that _may_ be able to disambiguate the column + # and we can use the source columns available at that join to get the table name + # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below + try: + table_name = self._get_table_name_from_sources( + column_name, self._get_available_source_columns(join_context) + ) + except OptimizeError: + pass + + if not table_name and self._infer_schema: + sources_without_schema = tuple( + source + for source, columns in self._get_all_source_columns().items() + if not columns or "*" in columns + ) + if len(sources_without_schema) == 1: + table_name = sources_without_schema[0] + + if table_name not in self.scope.selected_sources: + return exp.to_identifier(table_name) + + node, _ = self.scope.selected_sources.get(table_name) + + if isinstance(node, exp.Query): + while node and node.alias != table_name: + node = node.parent + + node_alias = node.args.get("alias") + if node_alias: + return exp.to_identifier(node_alias.this) + + return exp.to_identifier(table_name) + + @property + def all_columns(self) -> t.Set[str]: + """All available columns of all sources in this scope""" + if self._all_columns is None: + self._all_columns = { + column + for columns in self._get_all_source_columns().values() + for column in columns + } + return self._all_columns + + def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]: + if isinstance(expression, exp.Select): + return expression.named_selects + if isinstance(expression, exp.Subquery) and isinstance( + expression.this, exp.SetOperation + ): + # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting + return self.get_source_columns_from_set_op(expression.this) + if not isinstance(expression, exp.SetOperation): + raise OptimizeError(f"Unknown set operation: {expression}") + + set_op = expression + + # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME + on_column_list = set_op.args.get("on") + + if on_column_list: + # The resulting columns are the columns in the ON clause: + # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 
+ columns = [col.name for col in on_column_list] + elif set_op.side or set_op.kind: + side = set_op.side + kind = set_op.kind + + # Visit the children UNIONs (if any) in a post-order traversal + left = self.get_source_columns_from_set_op(set_op.left) + right = self.get_source_columns_from_set_op(set_op.right) + + # We use dict.fromkeys to deduplicate keys and maintain insertion order + if side == "LEFT": + columns = left + elif side == "FULL": + columns = list(dict.fromkeys(left + right)) + elif kind == "INNER": + columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) + else: + columns = set_op.named_selects + + return columns + + def get_source_columns( + self, name: str, only_visible: bool = False + ) -> t.Sequence[str]: + """Resolve the source columns for a given source `name`.""" + cache_key = (name, only_visible) + if cache_key not in self._get_source_columns_cache: + if name not in self.scope.sources: + raise OptimizeError(f"Unknown table: {name}") + + source = self.scope.sources[name] + + if isinstance(source, exp.Table): + columns = self.schema.column_names(source, only_visible) + elif isinstance(source, Scope) and isinstance( + source.expression, (exp.Values, exp.Unnest) + ): + columns = source.expression.named_selects + + # in bigquery, unnest structs are automatically scoped as tables, so you can + # directly select a struct field in a query. + # this handles the case where the unnest is statically defined. + if self.dialect.UNNEST_COLUMN_ONLY and isinstance( + source.expression, exp.Unnest + ): + unnest = source.expression + + # if type is not annotated yet, try to get it from the schema + if not unnest.type or unnest.type.is_type( + exp.DataType.Type.UNKNOWN + ): + unnest_expr = seq_get(unnest.expressions, 0) + if isinstance(unnest_expr, exp.Column) and self.scope.parent: + col_type = self._get_unnest_column_type(unnest_expr) + # extract element type if it's an ARRAY + if col_type and col_type.is_type(exp.DataType.Type.ARRAY): + element_types = col_type.expressions + if element_types: + unnest.type = element_types[0].copy() + else: + if col_type: + unnest.type = col_type.copy() + # check if the result type is a STRUCT - extract struct field names + if unnest.is_type(exp.DataType.Type.STRUCT): + for k in unnest.type.expressions: # type: ignore + columns.append(k.name) + elif isinstance(source, Scope) and isinstance( + source.expression, exp.SetOperation + ): + columns = self.get_source_columns_from_set_op(source.expression) + + else: + select = seq_get(source.expression.selects, 0) + + if isinstance(select, exp.QueryTransform): + # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html + schema = select.args.get("schema") + columns = ( + [c.name for c in schema.expressions] + if schema + else ["key", "value"] + ) + else: + columns = source.expression.named_selects + + node, _ = self.scope.selected_sources.get(name) or (None, None) + if isinstance(node, Scope): + column_aliases = node.expression.alias_column_names + elif isinstance(node, exp.Expression): + column_aliases = node.alias_column_names + else: + column_aliases = [] + + if column_aliases: + # If the source's columns are aliased, their aliases shadow the corresponding column names. + # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
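+ # e.g. (illustrative) columns=["a", "b"] with column_aliases=["x"] pairs up as + # ("a", "x"), ("b", None) under zip_longest, yielding ["x", "b"]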
+ columns = [ + alias or name + for (name, alias) in itertools.zip_longest(columns, column_aliases) + ] + + self._get_source_columns_cache[cache_key] = columns + + return self._get_source_columns_cache[cache_key] + + def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: + if self._source_columns is None: + self._source_columns = { + source_name: self.get_source_columns(source_name) + for source_name, source in itertools.chain( + self.scope.selected_sources.items(), + self.scope.lateral_sources.items(), + ) + } + return self._source_columns + + def _get_table_name_from_sources( + self, + column_name: str, + source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None, + ) -> t.Optional[str]: + if not source_columns: + # If not supplied, get all sources to calculate unambiguous columns + if self._unambiguous_columns is None: + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) + + unambiguous_columns = self._unambiguous_columns + else: + unambiguous_columns = self._get_unambiguous_columns(source_columns) + + return unambiguous_columns.get(column_name) + + def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]: + """ + Check if a column participating in a join can be qualified based on the source order. + """ + args = self.scope.expression.args + joins = args.get("joins") + + if not joins or args.get("laterals") or args.get("pivots"): + # Feature gap: We currently don't try to disambiguate columns if other sources + # (e.g laterals, pivots) exist alongside joins + return None + + join_ancestor = column.find_ancestor(exp.Join, exp.Select) + + if ( + isinstance(join_ancestor, exp.Join) + and join_ancestor.alias_or_name in self.scope.selected_sources + ): + # Ensure that the found ancestor is a join that contains an actual source, + # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b` + return join_ancestor + + return None + + def _get_available_source_columns( + self, join_ancestor: exp.Join + ) -> t.Dict[str, t.Sequence[str]]: + """ + Get the source columns that are available at the point where a column is referenced. + + For columns in JOIN conditions, this only includes tables that have been joined + up to that point. Example: + + ``` + SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ... + ``` ^ + | + +----------------------------------+ + | + ⌄ + The unqualified column `c` is not ambiguous if no other sources up until that + join i.e t_1, ..., t_n, contain a column named `c`. + + """ + args = self.scope.expression.args + + # Collect tables in order: FROM clause tables + joined tables up to current join + from_name = args["from_"].alias_or_name + available_sources = {from_name: self.get_source_columns(from_name)} + + for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]: + available_sources[join.alias_or_name] = self.get_source_columns( + join.alias_or_name + ) + + return available_sources + + def _get_unambiguous_columns( + self, source_columns: t.Dict[str, t.Sequence[str]] + ) -> t.Mapping[str, str]: + """ + Find all the unambiguous columns in sources. + + Args: + source_columns: Mapping of names to source columns. + + Returns: + Mapping of column name to source name. + """ + if not source_columns: + return {} + + source_columns_pairs = list(source_columns.items()) + + first_table, first_columns = source_columns_pairs[0] + + if len(source_columns_pairs) == 1: + # Performance optimization - avoid copying first_columns if there is only one table. 
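+ # SingleValuedMapping behaves like {col: first_table for col in first_columns} + # without materializing the dictionary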
+ return SingleValuedMapping(first_columns, first_table) + + unambiguous_columns = {col: first_table for col in first_columns} + all_columns = set(unambiguous_columns) + + for table, columns in source_columns_pairs[1:]: + unique = set(columns) + ambiguous = all_columns.intersection(unique) + all_columns.update(columns) + + for column in ambiguous: + unambiguous_columns.pop(column, None) + for column in unique.difference(ambiguous): + unambiguous_columns[column] = table + + return unambiguous_columns + + def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]: + """ + Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table. + + Args: + column: The column expression being unnested. + + Returns: + The DataType of the column, or None if not found. + """ + scope = self.scope.parent + + # if column is qualified, use that table, otherwise disambiguate using the resolver + if column.table: + table_name = column.table + else: + # use the parent scope's resolver to disambiguate the column + parent_resolver = Resolver(scope, self.schema, self._infer_schema) + table_identifier = parent_resolver.get_table(column) + if not table_identifier: + return None + table_name = table_identifier.name + + source = scope.sources.get(table_name) + return self._get_column_type_from_scope(source, column) if source else None + + def _get_column_type_from_scope( + self, source: t.Union[Scope, exp.Table], column: exp.Column + ) -> t.Optional[exp.DataType]: + """ + Get a column's type by tracing through scopes/tables to find the base table. + + Args: + source: The source to search - can be a Scope (to iterate its sources) or a Table. + column: The column to find the type for. + + Returns: + The DataType of the column, or None if not found. + """ + if isinstance(source, exp.Table): + # base table - get the column type from schema + col_type: t.Optional[exp.DataType] = self.schema.get_column_type( + source, column + ) + if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN): + return col_type + elif isinstance(source, Scope): + # iterate over all sources in the scope + for source_name, nested_source in source.sources.items(): + col_type = self._get_column_type_from_scope(nested_source, column) + if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN): + return col_type + + return None diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/scope.py b/third_party/bigframes_vendored/sqlglot/optimizer/scope.py new file mode 100644 index 0000000000..d7bcbda5a7 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/scope.py @@ -0,0 +1,981 @@ +from __future__ import annotations + +from collections import defaultdict +from enum import auto, Enum +import itertools +import logging +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.errors import OptimizeError +from bigframes_vendored.sqlglot.helper import ensure_collection, find_new_name, seq_get + +logger = logging.getLogger("sqlglot") + +TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) + + +class ScopeType(Enum): + ROOT = auto() + SUBQUERY = auto() + DERIVED_TABLE = auto() + CTE = auto() + UNION = auto() + UDTF = auto() + + +class Scope: + """ + Selection scope. + + Attributes: + expression (exp.Select|exp.SetOperation): Root expression of this scope + sources (dict[str, exp.Table|Scope]): Mapping of source name to either + a Table expression or another Scope instance. 
For example: + SELECT * FROM x {"x": Table(this="x")} + SELECT * FROM x AS y {"y": Table(this="x")} + SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + For example: + SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + The LATERAL VIEW EXPLODE gets x as a source. + cte_sources (dict[str, Scope]): Sources from CTEs + outer_columns (list[str]): If this is a derived table or CTE, and the outer query + defines a column list for the alias of this scope, this is that list of columns. + For example: + SELECT * FROM (SELECT ...) AS y(col1, col2) + The inner query would have `["col1", "col2"]` for its `outer_columns` + parent (Scope): Parent scope + scope_type (ScopeType): Type of this scope, relative to its parent + subquery_scopes (list[Scope]): List of all child scopes for subqueries + cte_scopes (list[Scope]): List of all child scopes for CTEs + derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined + union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + a list of the left and right child scopes. + """ + + def __init__( + self, + expression, + sources=None, + outer_columns=None, + parent=None, + scope_type=ScopeType.ROOT, + lateral_sources=None, + cte_sources=None, + can_be_correlated=None, + ): + self.expression = expression + self.sources = sources or {} + self.lateral_sources = lateral_sources or {} + self.cte_sources = cte_sources or {} + self.sources.update(self.lateral_sources) + self.sources.update(self.cte_sources) + self.outer_columns = outer_columns or [] + self.parent = parent + self.scope_type = scope_type + self.subquery_scopes = [] + self.derived_table_scopes = [] + self.table_scopes = [] + self.cte_scopes = [] + self.union_scopes = [] + self.udtf_scopes = [] + self.can_be_correlated = can_be_correlated + self.clear_cache() + + def clear_cache(self): + self._collected = False + self._raw_columns = None + self._table_columns = None + self._stars = None + self._derived_tables = None + self._udtfs = None + self._tables = None + self._ctes = None + self._subqueries = None + self._selected_sources = None + self._columns = None + self._external_columns = None + self._local_columns = None + self._join_hints = None + self._pivots = None + self._references = None + self._semi_anti_join_tables = None + + def branch( + self, + expression, + scope_type, + sources=None, + cte_sources=None, + lateral_sources=None, + **kwargs, + ): + """Branch from the current scope to a new, inner scope""" + return Scope( + expression=expression.unnest(), + sources=sources.copy() if sources else None, + parent=self, + scope_type=scope_type, + cte_sources={**self.cte_sources, **(cte_sources or {})}, + lateral_sources=lateral_sources.copy() if lateral_sources else None, + can_be_correlated=self.can_be_correlated + or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), + **kwargs, + ) + + def _collect(self): + self._tables = [] + self._ctes = [] + self._subqueries = [] + self._derived_tables = [] + self._udtfs = [] + self._raw_columns = [] + self._table_columns = [] + self._stars = [] + self._join_hints = [] + self._semi_anti_join_tables = set() + + for node in self.walk(bfs=False): + if node is self.expression: + continue + + if isinstance(node, exp.Dot) and node.is_star: + self._stars.append(node)
+ elif isinstance(node, exp.Column) and not isinstance( + node, exp.Pseudocolumn + ): + if isinstance(node.this, exp.Star): + self._stars.append(node) + else: + self._raw_columns.append(node) + elif isinstance(node, exp.Table) and not isinstance( + node.parent, exp.JoinHint + ): + parent = node.parent + if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: + self._semi_anti_join_tables.add(node.alias_or_name) + + self._tables.append(node) + elif isinstance(node, exp.JoinHint): + self._join_hints.append(node) + elif isinstance(node, exp.UDTF): + self._udtfs.append(node) + elif isinstance(node, exp.CTE): + self._ctes.append(node) + elif _is_derived_table(node) and _is_from_or_join(node): + self._derived_tables.append(node) + elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node): + self._subqueries.append(node) + elif isinstance(node, exp.TableColumn): + self._table_columns.append(node) + + self._collected = True + + def _ensure_collected(self): + if not self._collected: + self._collect() + + def walk(self, bfs=True, prune=None): + return walk_in_scope(self.expression, bfs=bfs, prune=prune) + + def find(self, *expression_types, bfs=True): + return find_in_scope(self.expression, expression_types, bfs=bfs) + + def find_all(self, *expression_types, bfs=True): + return find_all_in_scope(self.expression, expression_types, bfs=bfs) + + def replace(self, old, new): + """ + Replace `old` with `new`. + + This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. + + Args: + old (exp.Expression): old node + new (exp.Expression): new node + """ + old.replace(new) + self.clear_cache() + + @property + def tables(self): + """ + List of tables in this scope. + + Returns: + list[exp.Table]: tables + """ + self._ensure_collected() + return self._tables + + @property + def ctes(self): + """ + List of CTEs in this scope. + + Returns: + list[exp.CTE]: ctes + """ + self._ensure_collected() + return self._ctes + + @property + def derived_tables(self): + """ + List of derived tables in this scope. + + For example: + SELECT * FROM (SELECT ...) <- that's a derived table + + Returns: + list[exp.Subquery]: derived tables + """ + self._ensure_collected() + return self._derived_tables + + @property + def udtfs(self): + """ + List of "User Defined Tabular Functions" in this scope. + + Returns: + list[exp.UDTF]: UDTFs + """ + self._ensure_collected() + return self._udtfs + + @property + def subqueries(self): + """ + List of subqueries in this scope. + + For example: + SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery + + Returns: + list[exp.Select | exp.SetOperation]: subqueries + """ + self._ensure_collected() + return self._subqueries + + @property + def stars(self) -> t.List[exp.Column | exp.Dot]: + """ + List of star expressions (columns or dots) in this scope. + """ + self._ensure_collected() + return self._stars + + @property + def columns(self): + """ + List of columns in this scope. + + Returns: + list[exp.Column]: Column instances in this scope, plus any + Columns that reference this scope from correlated subqueries.
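+ + For example (illustrative), in + SELECT a FROM x WHERE x.b = (SELECT MAX(y.b) FROM y WHERE y.c = x.c) + the outer scope's columns include x.c, even though it only appears in the correlated subquery.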
+ """ + if self._columns is None: + self._ensure_collected() + columns = self._raw_columns + + external_columns = [ + column + for scope in itertools.chain( + self.subquery_scopes, + self.udtf_scopes, + (dts for dts in self.derived_table_scopes if dts.can_be_correlated), + ) + for column in scope.external_columns + ] + + named_selects = set(self.expression.named_selects) + + self._columns = [] + for column in columns + external_columns: + ancestor = column.find_ancestor( + exp.Select, + exp.Qualify, + exp.Order, + exp.Having, + exp.Hint, + exp.Table, + exp.Star, + exp.Distinct, + ) + if ( + not ancestor + or column.table + or isinstance(ancestor, exp.Select) + or ( + isinstance(ancestor, exp.Table) + and not isinstance(ancestor.this, exp.Func) + ) + or ( + isinstance(ancestor, (exp.Order, exp.Distinct)) + and ( + isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) + or not isinstance(ancestor.parent, exp.Select) + or column.name not in named_selects + ) + ) + or ( + isinstance(ancestor, exp.Star) + and not column.arg_key == "except_" + ) + ): + self._columns.append(column) + + return self._columns + + @property + def table_columns(self): + if self._table_columns is None: + self._ensure_collected() + + return self._table_columns + + @property + def selected_sources(self): + """ + Mapping of nodes and sources that are actually selected from in this scope. + + That is, all tables in a schema are selectable at any point. But a + table only becomes a selected source if it's included in a FROM or JOIN clause. + + Returns: + dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes + """ + if self._selected_sources is None: + result = {} + + for name, node in self.references: + if name in self._semi_anti_join_tables: + # The RHS table of SEMI/ANTI joins shouldn't be collected as a + # selected source + continue + + if name in result: + raise OptimizeError(f"Alias already used: {name}") + if name in self.sources: + result[name] = (node, self.sources[name]) + + self._selected_sources = result + return self._selected_sources + + @property + def references(self) -> t.List[t.Tuple[str, exp.Expression]]: + if self._references is None: + self._references = [] + + for table in self.tables: + self._references.append((table.alias_or_name, table)) + for expression in itertools.chain(self.derived_tables, self.udtfs): + self._references.append( + ( + _get_source_alias(expression), + expression + if expression.args.get("pivots") + else expression.unnest(), + ) + ) + + return self._references + + @property + def external_columns(self): + """ + Columns that appear to reference sources in outer scopes. + + Returns: + list[exp.Column]: Column instances that don't reference sources in the current scope. + """ + if self._external_columns is None: + if isinstance(self.expression, exp.SetOperation): + left, right = self.union_scopes + self._external_columns = left.external_columns + right.external_columns + else: + self._external_columns = [ + c + for c in self.columns + if c.table not in self.sources + and c.table not in self.semi_or_anti_join_tables + ] + + return self._external_columns + + @property + def local_columns(self): + """ + Columns in this scope that are not external. + + Returns: + list[exp.Column]: Column instances that reference sources in the current scope. 
+ """ + if self._local_columns is None: + external_columns = set(self.external_columns) + self._local_columns = [c for c in self.columns if c not in external_columns] + + return self._local_columns + + @property + def unqualified_columns(self): + """ + Unqualified columns in the current scope. + + Returns: + list[exp.Column]: Unqualified columns + """ + return [c for c in self.columns if not c.table] + + @property + def join_hints(self): + """ + Hints that exist in the scope that reference tables + + Returns: + list[exp.JoinHint]: Join hints that are referenced within the scope + """ + if self._join_hints is None: + return [] + return self._join_hints + + @property + def pivots(self): + if not self._pivots: + self._pivots = [ + pivot + for _, node in self.references + for pivot in node.args.get("pivots") or [] + ] + + return self._pivots + + @property + def semi_or_anti_join_tables(self): + return self._semi_anti_join_tables or set() + + def source_columns(self, source_name): + """ + Get all columns in the current scope for a particular source. + + Args: + source_name (str): Name of the source + Returns: + list[exp.Column]: Column instances that reference `source_name` + """ + return [column for column in self.columns if column.table == source_name] + + @property + def is_subquery(self): + """Determine if this scope is a subquery""" + return self.scope_type == ScopeType.SUBQUERY + + @property + def is_derived_table(self): + """Determine if this scope is a derived table""" + return self.scope_type == ScopeType.DERIVED_TABLE + + @property + def is_union(self): + """Determine if this scope is a union""" + return self.scope_type == ScopeType.UNION + + @property + def is_cte(self): + """Determine if this scope is a common table expression""" + return self.scope_type == ScopeType.CTE + + @property + def is_root(self): + """Determine if this is the root scope""" + return self.scope_type == ScopeType.ROOT + + @property + def is_udtf(self): + """Determine if this scope is a UDTF (User Defined Table Function)""" + return self.scope_type == ScopeType.UDTF + + @property + def is_correlated_subquery(self): + """Determine if this scope is a correlated subquery""" + return bool(self.can_be_correlated and self.external_columns) + + def rename_source(self, old_name, new_name): + """Rename a source in this scope""" + old_name = old_name or "" + if old_name in self.sources: + self.sources[new_name] = self.sources.pop(old_name) + + def add_source(self, name, source): + """Add a source to this scope""" + self.sources[name] = source + self.clear_cache() + + def remove_source(self, name): + """Remove a source from this scope""" + self.sources.pop(name, None) + self.clear_cache() + + def __repr__(self): + return f"Scope<{self.expression.sql()}>" + + def traverse(self): + """ + Traverse the scope tree from this node. + + Yields: + Scope: scope instances in depth-first-search post-order + """ + stack = [self] + result = [] + while stack: + scope = stack.pop() + result.append(scope) + stack.extend( + itertools.chain( + scope.cte_scopes, + scope.union_scopes, + scope.table_scopes, + scope.subquery_scopes, + ) + ) + + yield from reversed(result) + + def ref_count(self): + """ + Count the number of times each scope in this tree is referenced. 
+ + Returns: + dict[int, int]: Mapping of Scope instance ID to reference count + """ + scope_ref_count = defaultdict(lambda: 0) + + for scope in self.traverse(): + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + for name in scope._semi_anti_join_tables: + # semi/anti join sources are not actually selected but we still need to + # increment their ref count to avoid them being optimized away + if name in scope.sources: + scope_ref_count[id(scope.sources[name])] += 1 + + return scope_ref_count + + +def traverse_scope(expression: exp.Expression) -> t.List[Scope]: + """ + Traverse an expression by its "scopes". + + "Scope" represents the current context of a Select statement. + + This is helpful for optimizing queries, where we need more information than + the expression tree itself. For example, we might care about the source + names within a subquery. Returns a list because a generator could result in + incomplete properties which is confusing. + + Examples: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") + >>> scopes = traverse_scope(expression) + >>> scopes[0].expression.sql(), list(scopes[0].sources) + ('SELECT a FROM x', ['x']) + >>> scopes[1].expression.sql(), list(scopes[1].sources) + ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) + + Args: + expression: Expression to traverse + + Returns: + A list of the created scope instances + """ + if isinstance(expression, TRAVERSABLES): + return list(_traverse_scope(Scope(expression))) + return [] + + +def build_scope(expression: exp.Expression) -> t.Optional[Scope]: + """ + Build a scope tree. + + Args: + expression: Expression to build the scope tree for. + + Returns: + The root scope + """ + return seq_get(traverse_scope(expression), -1) + + +def _traverse_scope(scope): + expression = scope.expression + + if isinstance(expression, exp.Select): + yield from _traverse_select(scope) + elif isinstance(expression, exp.SetOperation): + yield from _traverse_ctes(scope) + yield from _traverse_union(scope) + return + elif isinstance(expression, exp.Subquery): + if scope.is_root: + yield from _traverse_select(scope) + else: + yield from _traverse_subqueries(scope) + elif isinstance(expression, exp.Table): + yield from _traverse_tables(scope) + elif isinstance(expression, exp.UDTF): + yield from _traverse_udtfs(scope) + elif isinstance(expression, exp.DDL): + if isinstance(expression.expression, exp.Query): + yield from _traverse_ctes(scope) + yield from _traverse_scope( + Scope(expression.expression, cte_sources=scope.cte_sources) + ) + return + elif isinstance(expression, exp.DML): + yield from _traverse_ctes(scope) + for query in find_all_in_scope(expression, exp.Query): + # This check ensures we don't yield the CTE/nested queries twice + if not isinstance(query.parent, (exp.CTE, exp.Subquery)): + yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) + return + else: + logger.warning( + "Cannot traverse scope %s with type '%s'", expression, type(expression) + ) + return + + yield scope + + +def _traverse_select(scope): + yield from _traverse_ctes(scope) + yield from _traverse_tables(scope) + yield from _traverse_subqueries(scope) + + +def _traverse_union(scope): + prev_scope = None + union_scope_stack = [scope] + expression_stack = [scope.expression.right, scope.expression.left] + + while expression_stack: + expression = expression_stack.pop() + union_scope = union_scope_stack[-1] + + new_scope = union_scope.branch( + expression, + 
outer_columns=union_scope.outer_columns, + scope_type=ScopeType.UNION, + ) + + if isinstance(expression, exp.SetOperation): + yield from _traverse_ctes(new_scope) + + union_scope_stack.append(new_scope) + expression_stack.extend([expression.right, expression.left]) + continue + + for scope in _traverse_scope(new_scope): + yield scope + + if prev_scope: + union_scope_stack.pop() + union_scope.union_scopes = [prev_scope, scope] + prev_scope = union_scope + + yield union_scope + else: + prev_scope = scope + + +def _traverse_ctes(scope): + sources = {} + + for cte in scope.ctes: + cte_name = cte.alias + + # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. + # thus the recursive scope is the first section of the union. + with_ = scope.expression.args.get("with_") + if with_ and with_.recursive: + union = cte.this + + if isinstance(union, exp.SetOperation): + sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) + + child_scope = None + + for child_scope in _traverse_scope( + scope.branch( + cte.this, + cte_sources=sources, + outer_columns=cte.alias_column_names, + scope_type=ScopeType.CTE, + ) + ): + yield child_scope + + # append the final child_scope yielded + if child_scope: + sources[cte_name] = child_scope + scope.cte_scopes.append(child_scope) + + scope.sources.update(sources) + scope.cte_sources.update(sources) + + +def _is_derived_table(expression: exp.Subquery) -> bool: + """ + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", + as it doesn't introduce a new scope. If an alias is present, it shadows all names + under the Subquery, so that's one exception to this rule. + """ + return isinstance(expression, exp.Subquery) and bool( + expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) + ) + + +def _is_from_or_join(expression: exp.Expression) -> bool: + """ + Determine if `expression` is the FROM or JOIN clause of a SELECT statement. + """ + parent = expression.parent + + # Subqueries can be arbitrarily nested + while isinstance(parent, exp.Subquery): + parent = parent.parent + + return isinstance(parent, (exp.From, exp.Join)) + + +def _traverse_tables(scope): + sources = {} + + # Traverse FROMs, JOINs, and LATERALs in the order they are defined + expressions = [] + from_ = scope.expression.args.get("from_") + if from_: + expressions.append(from_.this) + + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) + + if isinstance(scope.expression, exp.Table): + expressions.append(scope.expression) + + expressions.extend(scope.expression.args.get("laterals") or []) + + for expression in expressions: + if isinstance(expression, exp.Final): + expression = expression.this + if isinstance(expression, exp.Table): + table_name = expression.name + source_name = expression.alias_or_name + + if table_name in scope.sources and not expression.db: + # This is a reference to a parent source (e.g. a CTE), not an actual table, unless + # it is pivoted, because then we get back a new table and hence a new source. 
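+ # e.g. (illustrative) WITH c AS (...) SELECT * FROM c PIVOT(SUM(x) FOR y IN ('a')) AS p + # selects from the new pivoted source p, not from the CTE scope directly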
+ pivots = expression.args.get("pivots") + if pivots: + sources[pivots[0].alias] = expression + else: + sources[source_name] = scope.sources[table_name] + elif source_name in sources: + sources[find_new_name(sources, table_name)] = expression + else: + sources[source_name] = expression + + # Make sure to not include the joins twice + if expression is not scope.expression: + expressions.extend( + join.this for join in expression.args.get("joins") or [] + ) + + continue + + if not isinstance(expression, exp.DerivedTable): + continue + + if isinstance(expression, exp.UDTF): + lateral_sources = sources + scope_type = ScopeType.UDTF + scopes = scope.udtf_scopes + elif _is_derived_table(expression): + lateral_sources = None + scope_type = ScopeType.DERIVED_TABLE + scopes = scope.derived_table_scopes + expressions.extend(join.this for join in expression.args.get("joins") or []) + else: + # Makes sure we check for possible sources in nested table constructs + expressions.append(expression.this) + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue + + child_scope = None + + for child_scope in _traverse_scope( + scope.branch( + expression, + lateral_sources=lateral_sources, + outer_columns=expression.alias_column_names, + scope_type=scope_type, + ) + ): + yield child_scope + + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins). + sources[_get_source_alias(expression)] = child_scope + + # append the final child_scope yielded + if child_scope: + scopes.append(child_scope) + scope.table_scopes.append(child_scope) + + scope.sources.update(sources) + + +def _traverse_subqueries(scope): + for subquery in scope.subqueries: + top = None + for child_scope in _traverse_scope( + scope.branch(subquery, scope_type=ScopeType.SUBQUERY) + ): + yield child_scope + top = child_scope + scope.subquery_scopes.append(top) + + +def _traverse_udtfs(scope): + if isinstance(scope.expression, exp.Unnest): + expressions = scope.expression.expressions + elif isinstance(scope.expression, exp.Lateral): + expressions = [scope.expression.this] + else: + expressions = [] + + sources = {} + for expression in expressions: + if isinstance(expression, exp.Subquery): + top = None + for child_scope in _traverse_scope( + scope.branch( + expression, + scope_type=ScopeType.SUBQUERY, + outer_columns=expression.alias_column_names, + ) + ): + yield child_scope + top = child_scope + sources[_get_source_alias(expression)] = child_scope + + scope.subquery_scopes.append(top) + + scope.sources.update(sources) + + +def walk_in_scope(expression, bfs=True, prune=None): + """ + Returns a generator object which visits all nodes in the syntax tree, stopping at + nodes that start child scopes. + + Args: + expression (exp.Expression): + bfs (bool): if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + prune ((node) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. + + Yields: + exp.Expression: nodes + """ + # We'll use this variable to pass state into the walk generator. + # Whenever we set it to True, we exclude a subtree from traversal.
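+ # e.g. while walking "SELECT a FROM (SELECT a FROM x) AS y", the Subquery node itself is + # yielded, but the flag is set at that node, so the walk never descends into the inner SELECT.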
+ crossed_scope_boundary = False + + for node in expression.walk( + bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) + ): + crossed_scope_boundary = False + + yield node + + if node is expression: + continue + + if ( + isinstance(node, exp.CTE) + or ( + isinstance(node.parent, (exp.From, exp.Join)) + and _is_derived_table(node) + ) + or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query)) + or isinstance(node, exp.UNWRAPPED_QUERIES) + ): + crossed_scope_boundary = True + + if isinstance(node, (exp.Subquery, exp.UDTF)): + # The following args are not actually in the inner scope, so we should visit them + for key in ("joins", "laterals", "pivots"): + for arg in node.args.get(key) or []: + yield from walk_in_scope(arg, bfs=bfs) + + +def find_all_in_scope(expression, expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this scope and only yields those that + match at least one of the specified expression types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Yields: + exp.Expression: nodes + """ + for expression in walk_in_scope(expression, bfs=bfs): + if isinstance(expression, tuple(ensure_collection(expression_types))): + yield expression + + +def find_in_scope(expression, expression_types, bfs=True): + """ + Returns the first node in this scope which matches at least one of the specified types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Returns: + exp.Expression: the node which matches the criteria or None if no node matching + the criteria was found. 
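+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM x") + >>> find_in_scope(expression, exp.Table).sql() + 'x'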
+ """ + return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) + + +def _get_source_alias(expression): + alias_arg = expression.args.get("alias") + alias_name = expression.alias + + if ( + not alias_name + and isinstance(alias_arg, exp.TableAlias) + and len(alias_arg.columns) == 1 + ): + alias_name = alias_arg.columns[0].name + + return alias_name diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/simplify.py b/third_party/bigframes_vendored/sqlglot/optimizer/simplify.py new file mode 100644 index 0000000000..5bdbed834b --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/simplify.py @@ -0,0 +1,1794 @@ +from __future__ import annotations + +from collections import defaultdict, deque +import datetime +import functools +from functools import reduce, wraps +import itertools +import logging +import typing as t + +import bigframes_vendored.sqlglot +from bigframes_vendored.sqlglot import Dialect, exp +from bigframes_vendored.sqlglot.helper import first, merge_ranges, while_changing +from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator +from bigframes_vendored.sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope +from bigframes_vendored.sqlglot.schema import ensure_schema + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + DateRange = t.Tuple[datetime.date, datetime.date] + DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str, Dialect, exp.DataType], + t.Optional[exp.Expression], + ] + + +logger = logging.getLogger("sqlglot") + + +# Final means that an expression should not be simplified +FINAL = "final" + +SIMPLIFIABLE = ( + exp.Binary, + exp.Func, + exp.Lambda, + exp.Predicate, + exp.Unary, +) + + +def simplify( + expression: exp.Expression, + constant_propagation: bool = False, + coalesce_simplification: bool = False, + dialect: DialectType = None, +): + """ + Rewrite sqlglot AST to simplify expressions. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("TRUE AND TRUE") + >>> simplify(expression).sql() + 'TRUE' + + Args: + expression: expression to simplify + constant_propagation: whether the constant propagation rule should be used + coalesce_simplification: whether the simplify coalesce rule should be used. + This rule tries to remove coalesce functions, which can be useful in certain analyses but + can leave the query more verbose. 
+ Returns: + sqlglot.Expression: simplified expression + """ + return Simplifier(dialect=dialect).simplify( + expression, + constant_propagation=constant_propagation, + coalesce_simplification=coalesce_simplification, + ) + + +class UnsupportedUnit(Exception): + pass + + +def catch(*exceptions): + """Decorator that ignores a simplification function if any of `exceptions` are raised""" + + def decorator(func): + def wrapped(expression, *args, **kwargs): + try: + return func(expression, *args, **kwargs) + except exceptions: + return expression + + return wrapped + + return decorator + + +def annotate_types_on_change(func): + @wraps(func) + def _func( + self, expression: exp.Expression, *args, **kwargs + ) -> t.Optional[exp.Expression]: + new_expression = func(self, expression, *args, **kwargs) + + if new_expression is None: + return new_expression + + if self.annotate_new_expressions and expression != new_expression: + self._annotator.clear() + + # We annotate this to ensure new children nodes are also annotated + new_expression = self._annotator.annotate( + expression=new_expression, + annotate_scope=False, + ) + + # Whatever expression the original expression is transformed into needs to preserve + # the original type, otherwise the simplification could result in a different schema + new_expression.type = expression.type + + return new_expression + + return _func + + +def flatten(expression): + """ + A AND (B AND C) -> A AND B AND C + A OR (B OR C) -> A OR B OR C + """ + if isinstance(expression, exp.Connector): + for node in expression.args.values(): + child = node.unnest() + if isinstance(child, expression.__class__): + node.replace(child) + return expression + + +def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression: + if not isinstance(expression, exp.Paren): + return expression + + this = expression.this + parent = expression.parent + parent_is_predicate = isinstance(parent, exp.Predicate) + + if isinstance(this, exp.Select): + return expression + + if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): + return expression + + if ( + Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS + and isinstance(parent, exp.Dot) + and (isinstance(parent.right, (exp.Identifier, exp.Star))) + ): + return expression + + if ( + not isinstance(parent, (exp.Condition, exp.Binary)) + or isinstance(parent, exp.Paren) + or ( + not isinstance(this, exp.Binary) + and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) + ) + or ( + isinstance(this, exp.Predicate) + and not (parent_is_predicate or isinstance(parent, exp.Neg)) + ) + or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) + or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) + or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) + ): + return this + + return expression + + +def propagate_constants(expression, root=True): + """ + Propagate constants for conjunctions in DNF: + + SELECT * FROM t WHERE a = b AND b = 5 becomes + SELECT * FROM t WHERE a = 5 AND b = 5 + + Reference: https://www.sqlite.org/optoverview.html + """ + + if ( + isinstance(expression, exp.And) + and (root or not expression.same_parent) + and bigframes_vendored.sqlglot.optimizer.normalize.normalized( + expression, dnf=True + ) + ): + constant_mapping = {} + for expr in walk_in_scope( + expression, prune=lambda node: isinstance(node, exp.If) + ): + if isinstance(expr, exp.EQ): + l, r = expr.left, expr.right + + # TODO: create a helper that can be used to detect nested 
literal expressions such + # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too + if isinstance(l, exp.Column) and isinstance(r, exp.Literal): + constant_mapping[l] = (id(l), r) + + if constant_mapping: + for column in find_all_in_scope(expression, exp.Column): + parent = column.parent + column_id, constant = constant_mapping.get(column) or (None, None) + if ( + column_id is not None + and id(column) != column_id + and not ( + isinstance(parent, exp.Is) + and isinstance(parent.expression, exp.Null) + ) + ): + column.replace(constant.copy()) + + return expression + + +def _is_number(expression: exp.Expression) -> bool: + return expression.is_number + + +def _is_interval(expression: exp.Expression) -> bool: + return ( + isinstance(expression, exp.Interval) + and extract_interval(expression) is not None + ) + + +def _is_nonnull_constant(expression: exp.Expression) -> bool: + return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) + + +def _is_constant(expression: exp.Expression) -> bool: + return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) + + +def _datetrunc_range( + date: datetime.date, unit: str, dialect: Dialect +) -> t.Optional[DateRange]: + """ + Get the date range for a DATE_TRUNC equality comparison: + + Example: + _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) + Returns: + tuple of [min, max) or None if a value can never be equal to `date` for `unit` + """ + floor = date_floor(date, unit, dialect) + + if date != floor: + # This will always be False, except for NULL values. + return None + + return floor, floor + interval(unit) + + +def _datetrunc_eq_expression( + left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] +) -> exp.Expression: + """Get the logical expression for a date range""" + return exp.and_( + left >= date_literal(drange[0], target_type), + left < date_literal(drange[1], target_type), + copy=False, + ) + + +def _datetrunc_eq( + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit, dialect) + if not drange: + return None + + return _datetrunc_eq_expression(left, drange, target_type) + + +def _datetrunc_neq( + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit, dialect) + if not drange: + return None + + # A value differs from the truncated date when it falls outside the + # [floor, next boundary) bucket, i.e. below the floor OR at/above the next boundary + return exp.or_( + left < date_literal(drange[0], target_type), + left >= date_literal(drange[1], target_type), + copy=False, + ) + + +def always_true(expression): + return (isinstance(expression, exp.Boolean) and expression.this) or ( + isinstance(expression, exp.Literal) + and expression.is_number + and not is_zero(expression) + ) + + +def always_false(expression): + return is_false(expression) or is_null(expression) or is_zero(expression) + + +def is_zero(expression): + return isinstance(expression, exp.Literal) and expression.to_py() == 0 + + +def is_complement(a, b): + return isinstance(b, exp.Not) and b.this == a + + +def is_false(a: exp.Expression) -> bool: + return type(a) is exp.Boolean and not a.this + + +def is_null(a: exp.Expression) -> bool: + return type(a) is exp.Null + + +def eval_boolean(expression, a, b): + if isinstance(expression, (exp.EQ, exp.Is)): + return boolean_literal(a == b) + if isinstance(expression, exp.NEQ): + return
boolean_literal(a != b) + if isinstance(expression, exp.GT): + return boolean_literal(a > b) + if isinstance(expression, exp.GTE): + return boolean_literal(a >= b) + if isinstance(expression, exp.LT): + return boolean_literal(a < b) + if isinstance(expression, exp.LTE): + return boolean_literal(a <= b) + return None + + +def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: + if isinstance(value, datetime.datetime): + return value.date() + if isinstance(value, datetime.date): + return value + try: + return datetime.datetime.fromisoformat(value).date() + except ValueError: + return None + + +def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + return datetime.datetime(year=value.year, month=value.month, day=value.day) + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + return None + + +def cast_value( + value: t.Any, to: exp.DataType +) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if not value: + return None + if to.is_type(exp.DataType.Type.DATE): + return cast_as_date(value) + if to.is_type(*exp.DataType.TEMPORAL_TYPES): + return cast_as_datetime(value) + return None + + +def extract_date( + cast: exp.Expression, +) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if isinstance(cast, exp.Cast): + to = cast.to + elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): + to = exp.DataType.build(exp.DataType.Type.DATE) + else: + return None + + if isinstance(cast.this, exp.Literal): + value: t.Any = cast.this.name + elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): + value = extract_date(cast.this) + else: + return None + return cast_value(value, to) + + +def _is_date_literal(expression: exp.Expression) -> bool: + return extract_date(expression) is not None + + +def extract_interval(expression): + try: + n = int(expression.this.to_py()) + unit = expression.text("unit").lower() + return interval(unit, n) + except (UnsupportedUnit, ModuleNotFoundError, ValueError): + return None + + +def extract_type(*expressions): + target_type = None + for expression in expressions: + target_type = ( + expression.to if isinstance(expression, exp.Cast) else expression.type + ) + if target_type: + break + + return target_type + + +def date_literal(date, target_type=None): + if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): + target_type = ( + exp.DataType.Type.DATETIME + if isinstance(date, datetime.datetime) + else exp.DataType.Type.DATE + ) + + return exp.cast(exp.Literal.string(date), target_type) + + +def interval(unit: str, n: int = 1): + from dateutil.relativedelta import relativedelta + + if unit == "year": + return relativedelta(years=1 * n) + if unit == "quarter": + return relativedelta(months=3 * n) + if unit == "month": + return relativedelta(months=1 * n) + if unit == "week": + return relativedelta(weeks=1 * n) + if unit == "day": + return relativedelta(days=1 * n) + if unit == "hour": + return relativedelta(hours=1 * n) + if unit == "minute": + return relativedelta(minutes=1 * n) + if unit == "second": + return relativedelta(seconds=1 * n) + + raise UnsupportedUnit(f"Unsupported unit: {unit}") + + +def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: + if unit == "year": + return d.replace(month=1, day=1) + if unit == "quarter": + if d.month <= 3: + return d.replace(month=1, day=1) + elif d.month <= 6: + return d.replace(month=4, day=1) + elif d.month <= 9: 
+ return d.replace(month=7, day=1) + else: + return d.replace(month=10, day=1) + if unit == "month": + return d.replace(month=d.month, day=1) + if unit == "week": + # Assuming week starts on Monday (0) and ends on Sunday (6) + return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) + if unit == "day": + return d + + raise UnsupportedUnit(f"Unsupported unit: {unit}") + + +def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: + floor = date_floor(d, unit, dialect) + + if floor == d: + return d + + return floor + interval(unit) + + +def boolean_literal(condition): + return exp.true() if condition else exp.false() + + +class Simplifier: + def __init__( + self, dialect: DialectType = None, annotate_new_expressions: bool = True + ): + self.dialect = Dialect.get_or_raise(dialect) + self.annotate_new_expressions = annotate_new_expressions + + self._annotator: TypeAnnotator = TypeAnnotator( + schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False + ) + + # Value ranges for byte-sized signed/unsigned integers + TINYINT_MIN = -128 + TINYINT_MAX = 127 + UTINYINT_MIN = 0 + UTINYINT_MAX = 255 + + COMPLEMENT_COMPARISONS = { + exp.LT: exp.GTE, + exp.GT: exp.LTE, + exp.LTE: exp.GT, + exp.GTE: exp.LT, + exp.EQ: exp.NEQ, + exp.NEQ: exp.EQ, + } + + COMPLEMENT_SUBQUERY_PREDICATES = { + exp.All: exp.Any, + exp.Any: exp.All, + } + + LT_LTE = (exp.LT, exp.LTE) + GT_GTE = (exp.GT, exp.GTE) + + COMPARISONS = ( + *LT_LTE, + *GT_GTE, + exp.EQ, + exp.NEQ, + exp.Is, + ) + + INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.LT: exp.GT, + exp.GT: exp.LT, + exp.LTE: exp.GTE, + exp.GTE: exp.LTE, + } + + NONDETERMINISTIC = (exp.Rand, exp.Randn) + AND_OR = (exp.And, exp.Or) + + INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.DateAdd: exp.Sub, + exp.DateSub: exp.Add, + exp.DatetimeAdd: exp.Sub, + exp.DatetimeSub: exp.Add, + } + + INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + **INVERSE_DATE_OPS, + exp.Add: exp.Sub, + exp.Sub: exp.Add, + } + + NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) + + CONCATS = (exp.Concat, exp.DPipe) + + DATETRUNC_BINARY_COMPARISONS: t.Dict[ + t.Type[exp.Expression], DateTruncBinaryTransform + ] = { + exp.LT: lambda ll, dt, u, d, t: ll + < date_literal( + dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t + ), + exp.GT: lambda ll, dt, u, d, t: ll + >= date_literal(date_floor(dt, u, d) + interval(u), t), + exp.LTE: lambda ll, dt, u, d, t: ll + < date_literal(date_floor(dt, u, d) + interval(u), t), + exp.GTE: lambda ll, dt, u, d, t: ll >= date_literal(date_ceil(dt, u, d), t), + exp.EQ: _datetrunc_eq, + exp.NEQ: _datetrunc_neq, + } + + DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} + DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) + + SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean) + + # CROSS joins result in an empty table if the right table is empty. + # So we can only simplify certain types of joins to CROSS. 
+ # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x + JOINS = { + ("", ""), + ("", "INNER"), + ("RIGHT", ""), + ("RIGHT", "OUTER"), + } + + def simplify( + self, + expression: exp.Expression, + constant_propagation: bool = False, + coalesce_simplification: bool = False, + ): + wheres = [] + joins = [] + + for node in expression.walk( + prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL)) + ): + if node.meta.get(FINAL): + continue + + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + group = node.args.get("group") + + if group and hasattr(node, "selects"): + groups = set(group.expressions) + group.meta[FINAL] = True + + for s in node.selects: + for n in s.walk(): + if n in groups: + s.meta[FINAL] = True + break + + having = node.args.get("having") + + if having: + for n in having.walk(): + if n in groups: + having.meta[FINAL] = True + break + + if isinstance(node, exp.Condition): + simplified = while_changing( + node, + lambda e: self._simplify( + e, constant_propagation, coalesce_simplification + ), + ) + + if node is expression: + expression = simplified + elif isinstance(node, exp.Where): + wheres.append(node) + elif isinstance(node, exp.Join): + # Snowflake match_conditions have very strict ordering rules + if match := node.args.get("match_condition"): + match.meta[FINAL] = True + + joins.append(node) + + for where in wheres: + if always_true(where.this): + where.pop() + for join in joins: + if ( + always_true(join.args.get("on")) + and not join.args.get("using") + and not join.args.get("method") + and (join.side, join.kind) in self.JOINS + ): + join.args["on"].pop() + join.set("side", None) + join.set("kind", "CROSS") + + return expression + + def _simplify( + self, + expression: exp.Expression, + constant_propagation: bool, + coalesce_simplification: bool, + ): + pre_transformation_stack = [expression] + post_transformation_stack = [] + + while pre_transformation_stack: + original = pre_transformation_stack.pop() + node = original + + if not isinstance(node, SIMPLIFIABLE): + if isinstance(node, exp.Query): + self.simplify(node, constant_propagation, coalesce_simplification) + continue + + parent = node.parent + root = node is expression + + node = self.rewrite_between(node) + node = self.uniq_sort(node, root) + node = self.absorb_and_eliminate(node, root) + node = self.simplify_concat(node) + node = self.simplify_conditionals(node) + + if constant_propagation: + node = propagate_constants(node, root) + + if node is not original: + original.replace(node) + + pre_transformation_stack.extend( + n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL) + ) + post_transformation_stack.append((node, parent)) + + while post_transformation_stack: + original, parent = post_transformation_stack.pop() + root = original is expression + + # Resets parent, arg_key, index pointers; this is needed because some of the + # previous transformations mutate the AST, leading to an inconsistent state + for k, v in tuple(original.args.items()): + original.set(k, v) + + # Post-order transformations + node = self.simplify_not(original) + node = flatten(node) + node = self.simplify_connectors(node, root) + node = self.remove_complements(node, root) + + if coalesce_simplification: + node = self.simplify_coalesce(node) + node.parent = parent + + node = self.simplify_literals(node,
root) + node = self.simplify_equality(node) + node = simplify_parens(node, dialect=self.dialect) + node = self.simplify_datetrunc(node) + node = self.sort_comparison(node) + node = self.simplify_startswith(node) + + if node is not original: + original.replace(node) + + return node + + @annotate_types_on_change + def rewrite_between(self, expression: exp.Expression) -> exp.Expression: + """Rewrite x between y and z to x >= y AND x <= z. + + This is done because comparison simplification is only done on lt/lte/gt/gte. + """ + if isinstance(expression, exp.Between): + negate = isinstance(expression.parent, exp.Not) + + expression = exp.and_( + exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), + exp.LTE( + this=expression.this.copy(), expression=expression.args["high"] + ), + copy=False, + ) + + if negate: + expression = exp.paren(expression, copy=False) + + return expression + + @annotate_types_on_change + def simplify_not(self, expression: exp.Expression) -> exp.Expression: + """ + Demorgan's Law + NOT (x OR y) -> NOT x AND NOT y + NOT (x AND y) -> NOT x OR NOT y + """ + if isinstance(expression, exp.Not): + this = expression.this + if is_null(this): + return exp.and_(exp.null(), exp.true(), copy=False) + if this.__class__ in self.COMPLEMENT_COMPARISONS: + right = this.expression + complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get( + right.__class__ + ) + if complement_subquery_predicate: + right = complement_subquery_predicate(this=right.this) + + return self.COMPLEMENT_COMPARISONS[this.__class__]( + this=this.this, expression=right + ) + if isinstance(this, exp.Paren): + condition = this.unnest() + if isinstance(condition, exp.And): + return exp.paren( + exp.or_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ), + copy=False, + ) + if isinstance(condition, exp.Or): + return exp.paren( + exp.and_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ), + copy=False, + ) + if is_null(condition): + return exp.and_(exp.null(), exp.true(), copy=False) + if always_true(this): + return exp.false() + if is_false(this): + return exp.true() + if ( + isinstance(this, exp.Not) + and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION + ): + inner = this.this + if inner.is_type(exp.DataType.Type.BOOLEAN): + # double negation + # NOT NOT x -> x, if x is BOOLEAN type + return inner + return expression + + @annotate_types_on_change + def simplify_connectors(self, expression, root=True): + def _simplify_connectors(expression, left, right): + if isinstance(expression, exp.And): + if is_false(left) or is_false(right): + return exp.false() + if is_zero(left) or is_zero(right): + return exp.false() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and always_true(right)) + or (always_true(left) and is_null(right)) + ): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): + return left + return self._simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and always_false(right)) + or (always_false(left) and is_null(right)) + ): + return exp.null() + if is_false(left): + return right + if is_false(right): + return left + return self._simplify_comparison(expression, left, right, or_=True) + + if 
isinstance(expression, exp.Connector): + original_parent = expression.parent + expression = self._flat_simplify(expression, _simplify_connectors, root) + + # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need + # to ensure that the resulting type is boolean. We know this is true only for connectors, + # boolean values and columns that are essentially operands to a connector: + # + # A AND (((B))) + # ~ this is safe to keep because it will eventually be part of another connector + if not isinstance( + expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT + ) and not expression.is_type(exp.DataType.Type.BOOLEAN): + while True: + if isinstance(original_parent, exp.Connector): + break + if not isinstance(original_parent, exp.Paren): + expression = expression.and_(exp.true(), copy=False) + break + + original_parent = original_parent.parent + + return expression + + @annotate_types_on_change + def _simplify_comparison(self, expression, left, right, or_=False): + if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): + ll, lr = left.args.values() + rl, rr = right.args.values() + + largs = {ll, lr} + rargs = {rl, rr} + + matching = largs & rargs + columns = { + m + for m in matching + if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC) + } + + if matching and columns: + try: + l0 = first(largs - columns) + r = first(rargs - columns) + except StopIteration: + return expression + + if l0.is_number and r.is_number: + l0 = l0.to_py() + r = r.to_py() + elif l0.is_string and r.is_string: + l0 = l0.name + r = r.name + else: + l0 = extract_date(l0) + if not l0: + return None + r = extract_date(r) + if not r: + return None + # python won't compare date and datetime, but many engines will upcast + l0, r = cast_as_datetime(l0), cast_as_datetime(r) + + for (a, av), (b, bv) in itertools.permutations( + ((left, l0), (right, r)) + ): + if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE): + return left if (av > bv if or_ else av <= bv) else right + if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE): + return left if (av < bv if or_ else av >= bv) else right + + # we can't ever shortcut to true because the column could be null + if not or_: + if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE): + if av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE): + if av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a + return None + + @annotate_types_on_change + def remove_complements(self, expression, root=True): + """ + Removing complements. + + A AND NOT A -> FALSE (only for non-NULL A) + A OR NOT A -> TRUE (only for non-NULL A) + """ + if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): + ops = set(expression.flatten()) + for op in ops: + if isinstance(op, exp.Not) and op.this in ops: + if expression.meta.get("nonnull") is True: + return ( + exp.false() + if isinstance(expression, exp.And) + else exp.true() + ) + + return expression + + @annotate_types_on_change + def uniq_sort(self, expression, root=True): + """ + Uniq and sort a connector. 
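+ Operands of XOR are sorted but never deduplicated, since A XOR A is not equivalent to A.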
+ + C AND A AND B AND B -> A AND B AND C + """ + if isinstance(expression, exp.Connector) and ( + root or not expression.same_parent + ): + flattened = tuple(expression.flatten()) + + if isinstance(expression, exp.Xor): + result_func = exp.xor + # Do not deduplicate XOR as A XOR A != A if A == True + deduped = None + arr = tuple((gen(e), e) for e in flattened) + else: + result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ + deduped = {gen(e): e for e in flattened} + arr = tuple(deduped.items()) + + # check if the operands are already sorted, if not sort them + # A AND C AND B -> A AND B AND C + for i, (sql, e) in enumerate(arr[1:]): + if sql < arr[i][0]: + expression = result_func(*(e for _, e in sorted(arr)), copy=False) + break + else: + # we didn't have to sort but maybe we need to dedup + if deduped and len(deduped) < len(flattened): + unique_operand = flattened[0] + if len(deduped) == 1: + expression = unique_operand.and_(exp.true(), copy=False) + else: + expression = result_func(*deduped.values(), copy=False) + + return expression + + @annotate_types_on_change + def absorb_and_eliminate(self, expression, root=True): + """ + absorption: + A AND (A OR B) -> A + A OR (A AND B) -> A + A AND (NOT A OR B) -> A AND B + A OR (NOT A AND B) -> A OR B + elimination: + (A AND B) OR (A AND NOT B) -> A + (A OR B) AND (A OR NOT B) -> A + """ + if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): + kind = exp.Or if isinstance(expression, exp.And) else exp.And + + ops = tuple(expression.flatten()) + + # Initialize lookup tables: + # Set of all operands, used to find complements for absorption. + op_set = set() + # Sub-operands, used to find subsets for absorption. + subops = defaultdict(list) + # Pairs of complements, used for elimination. 
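+ # e.g. for (A AND B) OR (A AND NOT B), the entry keyed by frozenset((A, B)) pairs the + # two conjuncts so that elimination can later collapse both operands to just A.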
+ pairs = defaultdict(list) + + # Populate the lookup tables + for op in ops: + op_set.add(op) + + if not isinstance(op, kind): + # In cases like: A OR (A AND B) + # Subop will be: ^ + subops[op].append({op}) + continue + + # In cases like: (A AND B) OR (A AND B AND C) + # Subops will be: ^ ^ + subset = set(op.flatten()) + for i in subset: + subops[i].append(subset) + + a, b = op.unnest_operands() + if isinstance(a, exp.Not): + pairs[frozenset((a.this, b))].append((op, b)) + if isinstance(b, exp.Not): + pairs[frozenset((a, b.this))].append((op, a)) + + for op in ops: + if not isinstance(op, kind): + continue + + a, b = op.unnest_operands() + + # Absorb + if isinstance(a, exp.Not) and a.this in op_set: + a.replace(exp.true() if kind == exp.And else exp.false()) + continue + if isinstance(b, exp.Not) and b.this in op_set: + b.replace(exp.true() if kind == exp.And else exp.false()) + continue + superset = set(op.flatten()) + if any( + any(subset < superset for subset in subops[i]) for i in superset + ): + op.replace(exp.false() if kind == exp.And else exp.true()) + continue + + # Eliminate + for other, complement in pairs[frozenset((a, b))]: + op.replace(complement) + other.replace(complement) + + return expression + + @annotate_types_on_change + @catch(ModuleNotFoundError, UnsupportedUnit) + def simplify_equality(self, expression: exp.Expression) -> exp.Expression: + """ + Use the subtraction and addition properties of equality to simplify expressions: + + x + 1 = 3 becomes x = 2 + + There are two binary operations in the above expression: + and = + Here's how we reference all the operands in the code below: + + l r + x + 1 = 3 + a b + """ + if isinstance(expression, self.COMPARISONS): + ll, r = expression.left, expression.right + + if ll.__class__ not in self.INVERSE_OPS: + return expression + + if r.is_number: + a_predicate = _is_number + b_predicate = _is_number + elif _is_date_literal(r): + a_predicate = _is_date_literal + b_predicate = _is_interval + else: + return expression + + if ll.__class__ in self.INVERSE_DATE_OPS: + ll = t.cast(exp.IntervalOp, ll) + a = ll.this + b = ll.interval() + else: + ll = t.cast(exp.Binary, ll) + a, b = ll.left, ll.right + + if not a_predicate(a) and b_predicate(b): + pass + elif not a_predicate(b) and b_predicate(a): + a, b = b, a + else: + return expression + + return expression.__class__( + this=a, expression=self.INVERSE_OPS[ll.__class__](this=r, expression=b) + ) + return expression + + @annotate_types_on_change + def simplify_literals(self, expression, root=True): + if isinstance(expression, exp.Binary) and not isinstance( + expression, exp.Connector + ): + return self._flat_simplify(expression, self._simplify_binary, root) + + if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): + return expression.this.this + + if type(expression) in self.INVERSE_DATE_OPS: + return ( + self._simplify_binary( + expression, expression.this, expression.interval() + ) + or expression + ) + + return expression + + def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression: + if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): + this = self._simplify_integer_cast(expr.this) + else: + this = expr.this + + if isinstance(expr, exp.Cast) and this.is_int: + num = this.to_py() + + # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any + # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is + # engine-dependent + if ( + self.TINYINT_MIN <= num <= self.TINYINT_MAX + and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES + ) or ( + self.UTINYINT_MIN <= num <= self.UTINYINT_MAX + and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES + ): + return this + + return expr + + def _simplify_binary(self, expression, a, b): + if isinstance(expression, self.COMPARISONS): + a = self._simplify_integer_cast(a) + b = self._simplify_integer_cast(b) + + if isinstance(expression, exp.Is): + if isinstance(b, exp.Not): + c = b.this + not_ = True + else: + c = b + not_ = False + + if is_null(c): + if isinstance(a, exp.Literal): + return exp.true() if not_ else exp.false() + if is_null(a): + return exp.false() if not_ else exp.true() + elif isinstance(expression, self.NULL_OK): + return None + elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If): + return exp.null() + + if a.is_number and b.is_number: + num_a = a.to_py() + num_b = b.to_py() + + if isinstance(expression, exp.Add): + return exp.Literal.number(num_a + num_b) + if isinstance(expression, exp.Mul): + return exp.Literal.number(num_a * num_b) + + # We only simplify Sub, Div if a and b have the same parent because they're not associative + if isinstance(expression, exp.Sub): + return ( + exp.Literal.number(num_a - num_b) if a.parent is b.parent else None + ) + if isinstance(expression, exp.Div): + # engines have differing int div behavior so intdiv is not safe + if ( + isinstance(num_a, int) and isinstance(num_b, int) + ) or a.parent is not b.parent: + return None + return exp.Literal.number(num_a / num_b) + + boolean = eval_boolean(expression, num_a, num_b) + + if boolean: + return boolean + elif a.is_string and b.is_string: + boolean = eval_boolean(expression, a.this, b.this) + + if boolean: + return boolean + elif _is_date_literal(a) and isinstance(b, exp.Interval): + date, b = extract_date(a), extract_interval(b) + if date and b: + if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): + return date_literal(date + b, extract_type(a)) + if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): + return date_literal(date - b, extract_type(a)) + elif isinstance(a, exp.Interval) and _is_date_literal(b): + a, date = extract_interval(a), extract_date(b) + # you cannot subtract a date from an interval + if a and b and isinstance(expression, exp.Add): + return date_literal(a + date, extract_type(b)) + elif _is_date_literal(a) and _is_date_literal(b): + if isinstance(expression, exp.Predicate): + a, b = extract_date(a), extract_date(b) + boolean = eval_boolean(expression, a, b) + if boolean: + return boolean + + return None + + @annotate_types_on_change + def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression: + # COALESCE(x) -> x + if ( + isinstance(expression, exp.Coalesce) + and (not expression.expressions or _is_nonnull_constant(expression.this)) + # COALESCE is also used as a Spark partitioning hint + and not isinstance(expression.parent, exp.Hint) + ): + return expression.this + + if self.dialect.COALESCE_COMPARISON_NON_STANDARD: + return expression + + if not isinstance(expression, self.COMPARISONS): + return expression + + if isinstance(expression.left, exp.Coalesce): + coalesce = expression.left + other = expression.right + elif isinstance(expression.right, exp.Coalesce): + coalesce = expression.right + other = expression.left + else: + return expression + + # This 
transformation is valid for non-constants, + # but it really only does anything if they are both constants. + if not _is_constant(other): + return expression + + # Find the first constant arg + for arg_index, arg in enumerate(coalesce.expressions): + if _is_constant(arg): + break + else: + return expression + + coalesce.set("expressions", coalesce.expressions[:arg_index]) + + # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, + # since we already remove COALESCE at the top of this function. + coalesce = coalesce if coalesce.expressions else coalesce.this + + # This expression is more complex than when we started, but it will get simplified further + return exp.paren( + exp.or_( + exp.and_( + coalesce.is_(exp.null()).not_(copy=False), + expression.copy(), + copy=False, + ), + exp.and_( + coalesce.is_(exp.null()), + type(expression)(this=arg.copy(), expression=other.copy()), + copy=False, + ), + copy=False, + ), + copy=False, + ) + + @annotate_types_on_change + def simplify_concat(self, expression): + """Reduces all groups that contain string literals by concatenating them.""" + if not isinstance(expression, self.CONCATS) or ( + # We can't reduce a CONCAT_WS call if we don't statically know the separator + isinstance(expression, exp.ConcatWs) + and not expression.expressions[0].is_string + ): + return expression + + if isinstance(expression, exp.ConcatWs): + sep_expr, *expressions = expression.expressions + sep = sep_expr.name + concat_type = exp.ConcatWs + args = {} + else: + expressions = expression.expressions + sep = "" + concat_type = exp.Concat + args = { + "safe": expression.args.get("safe"), + "coalesce": expression.args.get("coalesce"), + } + + new_args = [] + for is_string_group, group in itertools.groupby( + expressions or expression.flatten(), lambda e: e.is_string + ): + if is_string_group: + new_args.append( + exp.Literal.string(sep.join(string.name for string in group)) + ) + else: + new_args.extend(group) + + if len(new_args) == 1 and new_args[0].is_string: + return new_args[0] + + if concat_type is exp.ConcatWs: + new_args = [sep_expr] + new_args + elif isinstance(expression, exp.DPipe): + return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) + + return concat_type(expressions=new_args, **args) + + @annotate_types_on_change + def simplify_conditionals(self, expression): + """Simplifies expressions like IF, CASE if their condition is statically known.""" + if isinstance(expression, exp.Case): + this = expression.this + for case in expression.args["ifs"]: + cond = case.this + if this: + # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... + cond = cond.replace(this.pop().eq(cond)) + + if always_true(cond): + return case.args["true"] + + if always_false(cond): + case.pop() + if not expression.args["ifs"]: + return expression.args.get("default") or exp.null() + elif isinstance(expression, exp.If) and not isinstance( + expression.parent, exp.Case + ): + if always_true(expression.this): + return expression.args["true"] + if always_false(expression.this): + return expression.args.get("false") or exp.null() + + return expression + + @annotate_types_on_change + def simplify_startswith(self, expression: exp.Expression) -> exp.Expression: + """ + Reduces a prefix check to either TRUE or FALSE if both the string and the + prefix are statically known. 
+ + Example: + >>> from bigframes_vendored.sqlglot import parse_one + >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() + 'TRUE' + """ + if ( + isinstance(expression, exp.StartsWith) + and expression.this.is_string + and expression.expression.is_string + ): + return exp.convert(expression.name.startswith(expression.expression.name)) + + return expression + + def _is_datetrunc_predicate( + self, left: exp.Expression, right: exp.Expression + ) -> bool: + return isinstance(left, self.DATETRUNCS) and _is_date_literal(right) + + @annotate_types_on_change + @catch(ModuleNotFoundError, UnsupportedUnit) + def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression: + """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" + comparison = expression.__class__ + + if isinstance(expression, self.DATETRUNCS): + this = expression.this + trunc_type = extract_type(this) + date = extract_date(this) + if date and expression.unit: + return date_literal( + date_floor(date, expression.unit.name.lower(), self.dialect), + trunc_type, + ) + elif comparison not in self.DATETRUNC_COMPARISONS: + return expression + + if isinstance(expression, exp.Binary): + ll, r = expression.left, expression.right + + if not self._is_datetrunc_predicate(ll, r): + return expression + + ll = t.cast(exp.DateTrunc, ll) + trunc_arg = ll.this + unit = ll.unit.name.lower() + date = extract_date(r) + + if not date: + return expression + + return ( + self.DATETRUNC_BINARY_COMPARISONS[comparison]( + trunc_arg, date, unit, self.dialect, extract_type(r) + ) + or expression + ) + + if isinstance(expression, exp.In): + ll = expression.this + rs = expression.expressions + + if rs and all(self._is_datetrunc_predicate(ll, r) for r in rs): + ll = t.cast(exp.DateTrunc, ll) + unit = ll.unit.name.lower() + + ranges = [] + for r in rs: + date = extract_date(r) + if not date: + return expression + drange = _datetrunc_range(date, unit, self.dialect) + if drange: + ranges.append(drange) + + if not ranges: + return expression + + ranges = merge_ranges(ranges) + target_type = extract_type(*rs) + + return exp.or_( + *[ + _datetrunc_eq_expression(ll, drange, target_type) + for drange in ranges + ], + copy=False, + ) + + return expression + + @annotate_types_on_change + def sort_comparison(self, expression: exp.Expression) -> exp.Expression: + if expression.__class__ in self.COMPLEMENT_COMPARISONS: + l, r = expression.this, expression.expression + l_column = isinstance(l, exp.Column) + r_column = isinstance(r, exp.Column) + l_const = _is_constant(l) + r_const = _is_constant(r) + + if ( + (l_column and not r_column) + or (r_const and not l_const) + or isinstance(r, exp.SubqueryPredicate) + ): + return expression + if ( + (r_column and not l_column) + or (l_const and not r_const) + or (gen(l) > gen(r)) + ): + return self.INVERSE_COMPARISONS.get( + expression.__class__, expression.__class__ + )(this=r, expression=l) + return expression + + def _flat_simplify(self, expression, simplifier, root=True): + if root or not expression.same_parent: + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = simplifier(expression, a, b) + + if result and result is not expression: + queue.remove(b) + queue.appendleft(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) + return expression 
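+ + # A minimal usage sketch for the simplifier above (illustrative only; the table "t" + # and column "x" are hypothetical names): + # + # from bigframes_vendored.sqlglot import parse_one + # + # expr = parse_one("SELECT * FROM t WHERE 1 = 1 AND x > 5") + # Simplifier().simplify(expr).sql() # -> 'SELECT * FROM t WHERE x > 5'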
+ + +def gen(expression: t.Any, comments: bool = False) -> str: + """Simple pseudo sql generator for quickly generating sortable and uniq strings. + + Sorting and deduping sql is a necessary step for optimization. Calling the actual + generator is expensive so we have a bare minimum sql generator here. + + Args: + expression: the expression to convert into a SQL string. + comments: whether to include the expression's comments. + """ + return Gen().gen(expression, comments=comments) + + +class Gen: + def __init__(self): + self.stack = [] + self.sqls = [] + + def gen(self, expression: exp.Expression, comments: bool = False) -> str: + self.stack = [expression] + self.sqls.clear() + + while self.stack: + node = self.stack.pop() + + if isinstance(node, exp.Expression): + if comments and node.comments: + self.stack.append(f" /*{','.join(node.comments)}*/") + + exp_handler_name = f"{node.key}_sql" + + if hasattr(self, exp_handler_name): + getattr(self, exp_handler_name)(node) + elif isinstance(node, exp.Func): + self._function(node) + else: + key = node.key.upper() + self.stack.append(f"{key} " if self._args(node) else key) + elif type(node) is list: + for n in reversed(node): + if n is not None: + self.stack.extend((n, ",")) + if node: + self.stack.pop() + else: + if node is not None: + self.sqls.append(str(node)) + + return "".join(self.sqls) + + def add_sql(self, e: exp.Add) -> None: + self._binary(e, " + ") + + def alias_sql(self, e: exp.Alias) -> None: + self.stack.extend( + ( + e.args.get("alias"), + " AS ", + e.args.get("this"), + ) + ) + + def and_sql(self, e: exp.And) -> None: + self._binary(e, " AND ") + + def anonymous_sql(self, e: exp.Anonymous) -> None: + this = e.this + if isinstance(this, str): + name = this.upper() + elif isinstance(this, exp.Identifier): + name = this.this + name = f'"{name}"' if this.quoted else name.upper() + else: + raise ValueError( + f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
+ ) + + self.stack.extend( + ( + ")", + e.expressions, + "(", + name, + ) + ) + + def between_sql(self, e: exp.Between) -> None: + self.stack.extend( + ( + e.args.get("high"), + " AND ", + e.args.get("low"), + " BETWEEN ", + e.this, + ) + ) + + def boolean_sql(self, e: exp.Boolean) -> None: + self.stack.append("TRUE" if e.this else "FALSE") + + def bracket_sql(self, e: exp.Bracket) -> None: + self.stack.extend( + ( + "]", + e.expressions, + "[", + e.this, + ) + ) + + def column_sql(self, e: exp.Column) -> None: + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def datatype_sql(self, e: exp.DataType) -> None: + self._args(e, 1) + self.stack.append(f"{e.this.name} ") + + def div_sql(self, e: exp.Div) -> None: + self._binary(e, " / ") + + def dot_sql(self, e: exp.Dot) -> None: + self._binary(e, ".") + + def eq_sql(self, e: exp.EQ) -> None: + self._binary(e, " = ") + + def from_sql(self, e: exp.From) -> None: + self.stack.extend((e.this, "FROM ")) + + def gt_sql(self, e: exp.GT) -> None: + self._binary(e, " > ") + + def gte_sql(self, e: exp.GTE) -> None: + self._binary(e, " >= ") + + def identifier_sql(self, e: exp.Identifier) -> None: + self.stack.append(f'"{e.this}"' if e.quoted else e.this) + + def ilike_sql(self, e: exp.ILike) -> None: + self._binary(e, " ILIKE ") + + def in_sql(self, e: exp.In) -> None: + self.stack.append(")") + self._args(e, 1) + self.stack.extend( + ( + "(", + " IN ", + e.this, + ) + ) + + def intdiv_sql(self, e: exp.IntDiv) -> None: + self._binary(e, " DIV ") + + def is_sql(self, e: exp.Is) -> None: + self._binary(e, " IS ") + + def like_sql(self, e: exp.Like) -> None: + self._binary(e, " Like ") + + def literal_sql(self, e: exp.Literal) -> None: + self.stack.append(f"'{e.this}'" if e.is_string else e.this) + + def lt_sql(self, e: exp.LT) -> None: + self._binary(e, " < ") + + def lte_sql(self, e: exp.LTE) -> None: + self._binary(e, " <= ") + + def mod_sql(self, e: exp.Mod) -> None: + self._binary(e, " % ") + + def mul_sql(self, e: exp.Mul) -> None: + self._binary(e, " * ") + + def neg_sql(self, e: exp.Neg) -> None: + self._unary(e, "-") + + def neq_sql(self, e: exp.NEQ) -> None: + self._binary(e, " <> ") + + def not_sql(self, e: exp.Not) -> None: + self._unary(e, "NOT ") + + def null_sql(self, e: exp.Null) -> None: + self.stack.append("NULL") + + def or_sql(self, e: exp.Or) -> None: + self._binary(e, " OR ") + + def paren_sql(self, e: exp.Paren) -> None: + self.stack.extend( + ( + ")", + e.this, + "(", + ) + ) + + def sub_sql(self, e: exp.Sub) -> None: + self._binary(e, " - ") + + def subquery_sql(self, e: exp.Subquery) -> None: + self._args(e, 2) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + self.stack.extend((")", e.this, "(")) + + def table_sql(self, e: exp.Table) -> None: + self._args(e, 4) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def tablealias_sql(self, e: exp.TableAlias) -> None: + columns = e.columns + + if columns: + self.stack.extend((")", columns, "(")) + + self.stack.extend((e.this, " AS ")) + + def var_sql(self, e: exp.Var) -> None: + self.stack.append(e.this) + + def _binary(self, e: exp.Binary, op: str) -> None: + self.stack.extend((e.expression, op, e.this)) + + def _unary(self, e: exp.Unary, op: str) -> None: + self.stack.extend((e.this, op)) + + def _function(self, e: exp.Func) -> None: + self.stack.extend( + ( + ")", + list(e.args.values()), + "(", + e.sql_name(), + ) + ) + + def 
_args(self, node: exp.Expression, arg_index: int = 0) -> bool: + kvs = [] + arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types + + for k in arg_types: + v = node.args.get(k) + + if v is not None: + kvs.append([f":{k}", v]) + if kvs: + self.stack.append(kvs) + return True + return False diff --git a/third_party/bigframes_vendored/sqlglot/optimizer/unnest_subqueries.py b/third_party/bigframes_vendored/sqlglot/optimizer/unnest_subqueries.py new file mode 100644 index 0000000000..acae6d8f1e --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/optimizer/unnest_subqueries.py @@ -0,0 +1,329 @@ +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.helper import name_sequence +from bigframes_vendored.sqlglot.optimizer.scope import ( + find_in_scope, + ScopeType, + traverse_scope, +) + + +def unnest_subqueries(expression): + """ + Rewrite sqlglot AST to convert some predicates with subqueries into joins. + + Convert scalar subqueries into cross joins. + Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") + >>> unnest_subqueries(expression).sql() + 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1' + + Args: + expression (sqlglot.Expression): expression to unnest + Returns: + sqlglot.Expression: unnested expression + """ + next_alias_name = name_sequence("_u_") + + for scope in traverse_scope(expression): + select = scope.expression + parent = select.parent_select + if not parent: + continue + if scope.external_columns: + decorrelate(select, parent, scope.external_columns, next_alias_name) + elif scope.scope_type == ScopeType.SUBQUERY: + unnest(select, parent, next_alias_name) + + return expression + + +def unnest(select, parent_select, next_alias_name): + if len(select.selects) > 1: + return + + predicate = select.find_ancestor(exp.Condition) + if ( + not predicate + or parent_select is not predicate.parent_select + or not parent_select.args.get("from_") + ): + return + + if isinstance(select, exp.SetOperation): + select = exp.select(*select.selects).from_(select.subquery(next_alias_name())) + + alias = next_alias_name() + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + + # This subquery returns a scalar and can just be converted to a cross join + if not isinstance(predicate, (exp.In, exp.Any)): + column = exp.column(select.selects[0].alias_or_name, alias) + + clause_parent_select = clause.parent_select if clause else None + + if ( + isinstance(clause, exp.Having) and clause_parent_select is parent_select + ) or ( + (not clause or clause_parent_select is not parent_select) + and ( + parent_select.args.get("group") + or any( + find_in_scope(select, exp.AggFunc) + for select in parent_select.selects + ) + ) + ): + column = exp.Max(this=column) + elif not isinstance(select.parent, exp.Subquery): + return + + join_type = "CROSS" + on_clause = None + if isinstance(predicate, exp.Exists): + # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows + # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then + # check for non-NULL column values to determine whether the subquery contained rows. 
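+ # e.g. "SELECT * FROM x WHERE EXISTS(SELECT a FROM y)" becomes a LEFT JOIN ... ON TRUE + # against the subquery, with the EXISTS predicate rewritten to "NOT <joined column> IS NULL".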
+ column = column.is_(exp.null()).not_() + join_type = "LEFT" + on_clause = exp.true() + + _replace(select.parent, column) + parent_select.join( + select, on=on_clause, join_type=join_type, join_alias=alias, copy=False + ) + return + + if select.find(exp.Limit, exp.Offset): + return + + if isinstance(predicate, exp.Any): + predicate = predicate.find_ancestor(exp.EQ) + + if not predicate or parent_select is not predicate.parent_select: + return + + column = _other_operand(predicate) + value = select.selects[0] + + join_key = exp.column(value.alias, alias) + join_key_not_null = join_key.is_(exp.null()).not_() + + if isinstance(clause, exp.Join): + _replace(predicate, exp.true()) + parent_select.where(join_key_not_null, copy=False) + else: + _replace(predicate, join_key_not_null) + + group = select.args.get("group") + + if group: + if {value.this} != set(group.expressions): + select = ( + exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias)) + .from_(select.subquery("_q", copy=False), copy=False) + .group_by(exp.column(value.alias, "_q"), copy=False) + ) + elif not find_in_scope(value.this, exp.AggFunc): + select = select.group_by(value.this, copy=False) + + parent_select.join( + select, + on=column.eq(join_key), + join_type="LEFT", + join_alias=alias, + copy=False, + ) + + +def decorrelate(select, parent_select, external_columns, next_alias_name): + where = select.args.get("where") + + if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): + return + + table_alias = next_alias_name() + keys = [] + + # for all external columns in the where statement, find the relevant predicate + # keys to convert it into a join + for column in external_columns: + if column.find_ancestor(exp.Where) is not where: + return + + predicate = column.find_ancestor(exp.Predicate) + + if not predicate or predicate.find_ancestor(exp.Where) is not where: + return + + if isinstance(predicate, exp.Binary): + key = ( + predicate.right + if any(node is column for node in predicate.left.walk()) + else predicate.left + ) + else: + return + + keys.append((key, column, predicate)) + + if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): + return + + is_subquery_projection = any( + node is select.parent + for node in map(lambda s: s.unalias(), parent_select.selects) + if isinstance(node, exp.Subquery) + ) + + value = select.selects[0] + key_aliases = {} + group_by = [] + + for key, _, predicate in keys: + # if we filter on the value of the subquery, it needs to be unique + if key == value.this: + key_aliases[key] = value.alias + group_by.append(key) + else: + if key not in key_aliases: + key_aliases[key] = next_alias_name() + # all predicates that are equalities must also be in the unique + # so that we don't do a many to many join + if isinstance(predicate, exp.EQ) and key not in group_by: + group_by.append(key) + + parent_predicate = select.find_ancestor(exp.Predicate) + + # if the value of the subquery is not an agg or a key, we need to collect it into an array + # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. 
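+ # e.g. a correlated subquery in the SELECT list keeps a single value per row via MAX, while + # one used in a predicate collects candidates with ARRAY_AGG so that the comparison can + # later be rewritten as ARRAY_ANY over the collected values.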
+    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
+    if not value.find(exp.AggFunc) and value.this not in group_by:
+        select.select(
+            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
+            append=False,
+            copy=False,
+        )
+
+    # exists queries should not have any selects as they only check if there are any rows
+    # all selects will be added by the optimizer and only used for join keys
+    if isinstance(parent_predicate, exp.Exists):
+        select.set("expressions", [])
+
+    for key, alias in key_aliases.items():
+        if key in group_by:
+            # add all keys to the projections of the subquery
+            # so that we can use it as a join key
+            if isinstance(parent_predicate, exp.Exists) or key != value.this:
+                select.select(f"{key} AS {alias}", copy=False)
+        else:
+            select.select(
+                exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False
+            )
+
+    alias = exp.column(value.alias, table_alias)
+    other = _other_operand(parent_predicate)
+    op_type = type(parent_predicate.parent) if parent_predicate else None
+
+    if isinstance(parent_predicate, exp.Exists):
+        alias = exp.column(list(key_aliases.values())[0], table_alias)
+        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
+    elif isinstance(parent_predicate, exp.All):
+        assert issubclass(op_type, exp.Binary)
+        predicate = op_type(this=other, expression=exp.column("_x"))
+        parent_predicate = _replace(
+            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
+        )
+    elif isinstance(parent_predicate, exp.Any):
+        assert issubclass(op_type, exp.Binary)
+        if value.this in group_by:
+            predicate = op_type(this=other, expression=alias)
+            parent_predicate = _replace(parent_predicate.parent, predicate)
+        else:
+            predicate = op_type(this=other, expression=exp.column("_x"))
+            parent_predicate = _replace(
+                parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})"
+            )
+    elif isinstance(parent_predicate, exp.In):
+        if value.this in group_by:
+            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
+        else:
+            parent_predicate = _replace(
+                parent_predicate,
+                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
+            )
+    else:
+        if is_subquery_projection and select.parent.alias:
+            alias = exp.alias_(alias, select.parent.alias)
+
+        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
+        # by transforming all counts into 0 and using that as the coalesced value
+        if value.find(exp.Count):
+
+            def remove_aggs(node):
+                if isinstance(node, exp.Count):
+                    return exp.Literal.number(0)
+                elif isinstance(node, exp.AggFunc):
+                    return exp.null()
+                return node
+
+            alias = exp.Coalesce(
+                this=alias, expressions=[value.this.transform(remove_aggs)]
+            )
+
+        select.parent.replace(alias)
+
+    for key, column, predicate in keys:
+        predicate.replace(exp.true())
+        nested = exp.column(key_aliases[key], table_alias)
+
+        if is_subquery_projection:
+            key.replace(nested)
+            if not isinstance(predicate, exp.EQ):
+                parent_select.where(predicate, copy=False)
+            continue
+
+        if key in group_by:
+            key.replace(nested)
+        elif isinstance(predicate, exp.EQ):
+            parent_predicate = _replace(
+                parent_predicate,
+                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+            )
+        else:
+            key.replace(exp.to_identifier("_x"))
+            parent_predicate = _replace(
+                parent_predicate,
+                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
+            )
+
+    parent_select.join(
+        select.group_by(*group_by, copy=False),
+        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
+        join_type="LEFT",
+        join_alias=table_alias,
+        copy=False,
+    )
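+
+
+# _replace swaps `expression` for `condition` parsed as a SQL expression;
+# _other_operand returns the non-subquery side of a binary predicate,
+# e.g. the literal 1 in `1 = ANY (SELECT ...)`.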
+
+
+def _replace(expression, condition):
+    return expression.replace(exp.condition(condition))
+
+
+def _other_operand(expression):
+    if isinstance(expression, exp.In):
+        return expression.this
+
+    if isinstance(expression, (exp.Any, exp.All)):
+        return _other_operand(expression.parent)
+
+    if isinstance(expression, exp.Binary):
+        return (
+            expression.right
+            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+            else expression.left
+        )
+
+    return None
diff --git a/third_party/bigframes_vendored/sqlglot/parser.py b/third_party/bigframes_vendored/sqlglot/parser.py
new file mode 100644
index 0000000000..2f93508293
--- /dev/null
+++ b/third_party/bigframes_vendored/sqlglot/parser.py
@@ -0,0 +1,9712 @@
+from __future__ import annotations
+
+from collections import defaultdict
+import itertools
+import logging
+import re
+import typing as t
+
+from bigframes_vendored.sqlglot import exp
+from bigframes_vendored.sqlglot.errors import (
+    concat_messages,
+    ErrorLevel,
+    highlight_sql,
+    merge_errors,
+    ParseError,
+    TokenError,
+)
+from bigframes_vendored.sqlglot.helper import apply_index_offset, ensure_list, seq_get
+from bigframes_vendored.sqlglot.time import format_time
+from bigframes_vendored.sqlglot.tokens import Token, Tokenizer, TokenType
+from bigframes_vendored.sqlglot.trie import in_trie, new_trie, TrieResult
+
+if t.TYPE_CHECKING:
+    from bigframes_vendored.sqlglot._typing import E, Lit
+    from bigframes_vendored.sqlglot.dialects.dialect import Dialect, DialectType
+
+    T = t.TypeVar("T")
+    TCeilFloor = t.TypeVar("TCeilFloor", exp.Ceil, exp.Floor)
+
+logger = logging.getLogger("sqlglot")
+
+OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]]
+
+# Used to detect alphabetical characters and +/- in timestamp literals
+TIME_ZONE_RE: t.Pattern[str] = re.compile(r":.*?[a-zA-Z\+\-]")
+
+
+def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
+    if len(args) == 1 and args[0].is_star:
+        return exp.StarMap(this=args[0])
+
+    keys = []
+    values = []
+    for i in range(0, len(args), 2):
+        keys.append(args[i])
+        values.append(args[i + 1])
+
+    return exp.VarMap(
+        keys=exp.array(*keys, copy=False), values=exp.array(*values, copy=False)
+    )
+
+
+def build_like(args: t.List) -> exp.Escape | exp.Like:
+    like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
+    return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
+
+
+def binary_range_parser(
+    expr_type: t.Type[exp.Expression], reverse_args: bool = False
+) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
+    def _parse_binary_range(
+        self: Parser, this: t.Optional[exp.Expression]
+    ) -> t.Optional[exp.Expression]:
+        expression = self._parse_bitwise()
+        if reverse_args:
+            this, expression = expression, this
+        return self._parse_escape(
+            self.expression(expr_type, this=this, expression=expression)
+        )
+
+    return _parse_binary_range
+
+
+def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
+    # Default argument order is base, expression
+    this = seq_get(args, 0)
+    expression = seq_get(args, 1)
+
+    if expression:
+        if not dialect.LOG_BASE_FIRST:
+            this, expression = expression, this
+        return exp.Log(this=this, expression=expression)
+
+    return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
+
+
+def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.LowerHex:
+    arg = seq_get(args, 0)
+    return exp.LowerHex(this=arg) if dialect.HEX_LOWERCASE else
exp.Hex(this=arg) + + +def build_lower(args: t.List) -> exp.Lower | exp.Hex: + # LOWER(HEX(..)) can be simplified to LowerHex to simplify its transpilation + arg = seq_get(args, 0) + return ( + exp.LowerHex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Lower(this=arg) + ) + + +def build_upper(args: t.List) -> exp.Upper | exp.Hex: + # UPPER(HEX(..)) can be simplified to Hex to simplify its transpilation + arg = seq_get(args, 0) + return exp.Hex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Upper(this=arg) + + +def build_extract_json_with_path( + expr_type: t.Type[E], +) -> t.Callable[[t.List, Dialect], E]: + def _builder(args: t.List, dialect: Dialect) -> E: + expression = expr_type( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ) + if len(args) > 2 and expr_type is exp.JSONExtract: + expression.set("expressions", args[2:]) + if expr_type is exp.JSONExtractScalar: + expression.set("scalar_only", dialect.JSON_EXTRACT_SCALAR_SCALAR_ONLY) + + return expression + + return _builder + + +def build_mod(args: t.List) -> exp.Mod: + this = seq_get(args, 0) + expression = seq_get(args, 1) + + # Wrap the operands if they are binary nodes, e.g. MOD(a + 1, 7) -> (a + 1) % 7 + this = exp.Paren(this=this) if isinstance(this, exp.Binary) else this + expression = ( + exp.Paren(this=expression) if isinstance(expression, exp.Binary) else expression + ) + + return exp.Mod(this=this, expression=expression) + + +def build_pad(args: t.List, is_left: bool = True): + return exp.Pad( + this=seq_get(args, 0), + expression=seq_get(args, 1), + fill_pattern=seq_get(args, 2), + is_left=is_left, + ) + + +def build_array_constructor( + exp_class: t.Type[E], args: t.List, bracket_kind: TokenType, dialect: Dialect +) -> exp.Expression: + array_exp = exp_class(expressions=args) + + if exp_class == exp.Array and dialect.HAS_DISTINCT_ARRAY_CONSTRUCTORS: + array_exp.set("bracket_notation", bracket_kind == TokenType.L_BRACKET) + + return array_exp + + +def build_convert_timezone( + args: t.List, default_source_tz: t.Optional[str] = None +) -> t.Union[exp.ConvertTimezone, exp.Anonymous]: + if len(args) == 2: + source_tz = exp.Literal.string(default_source_tz) if default_source_tz else None + return exp.ConvertTimezone( + source_tz=source_tz, target_tz=seq_get(args, 0), timestamp=seq_get(args, 1) + ) + + return exp.ConvertTimezone.from_arg_list(args) + + +def build_trim(args: t.List, is_left: bool = True): + return exp.Trim( + this=seq_get(args, 0), + expression=seq_get(args, 1), + position="LEADING" if is_left else "TRAILING", + ) + + +def build_coalesce( + args: t.List, is_nvl: t.Optional[bool] = None, is_null: t.Optional[bool] = None +) -> exp.Coalesce: + return exp.Coalesce( + this=seq_get(args, 0), expressions=args[1:], is_nvl=is_nvl, is_null=is_null + ) + + +def build_locate_strposition(args: t.List): + return exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ) + + +class _Parser(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS) + + return klass + + +class Parser(metaclass=_Parser): + """ + Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree. + + Args: + error_level: The desired error level. 
+ Default: ErrorLevel.IMMEDIATE + error_message_context: The amount of context to capture from a query string when displaying + the error message (in number of characters). + Default: 100 + max_errors: Maximum number of error messages to include in a raised ParseError. + This is only relevant if error_level is ErrorLevel.RAISE. + Default: 3 + """ + + FUNCTIONS: t.Dict[str, t.Callable] = { + **{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()}, + **dict.fromkeys(("COALESCE", "IFNULL", "NVL"), build_coalesce), + "ARRAY": lambda args, dialect: exp.Array(expressions=args), + "ARRAYAGG": lambda args, dialect: exp.ArrayAgg( + this=seq_get(args, 0), + nulls_excluded=dialect.ARRAY_AGG_INCLUDES_NULLS is None or None, + ), + "ARRAY_AGG": lambda args, dialect: exp.ArrayAgg( + this=seq_get(args, 0), + nulls_excluded=dialect.ARRAY_AGG_INCLUDES_NULLS is None or None, + ), + "CHAR": lambda args: exp.Chr(expressions=args), + "CHR": lambda args: exp.Chr(expressions=args), + "COUNT": lambda args: exp.Count( + this=seq_get(args, 0), expressions=args[1:], big_int=True + ), + "CONCAT": lambda args, dialect: exp.Concat( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), + "CONCAT_WS": lambda args, dialect: exp.ConcatWs( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), + "CONVERT_TIMEZONE": build_convert_timezone, + "DATE_TO_DATE_STR": lambda args: exp.Cast( + this=seq_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "GENERATE_DATE_ARRAY": lambda args: exp.GenerateDateArray( + start=seq_get(args, 0), + end=seq_get(args, 1), + step=seq_get(args, 2) + or exp.Interval(this=exp.Literal.string(1), unit=exp.var("DAY")), + ), + "GENERATE_UUID": lambda args, dialect: exp.Uuid( + is_string=dialect.UUID_IS_STRING_TYPE or None + ), + "GLOB": lambda args: exp.Glob( + this=seq_get(args, 1), expression=seq_get(args, 0) + ), + "GREATEST": lambda args, dialect: exp.Greatest( + this=seq_get(args, 0), + expressions=args[1:], + ignore_nulls=dialect.LEAST_GREATEST_IGNORES_NULLS, + ), + "LEAST": lambda args, dialect: exp.Least( + this=seq_get(args, 0), + expressions=args[1:], + ignore_nulls=dialect.LEAST_GREATEST_IGNORES_NULLS, + ), + "HEX": build_hex, + "JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar), + "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar), + "LIKE": build_like, + "LOG": build_logarithm, + "LOG2": lambda args: exp.Log( + this=exp.Literal.number(2), expression=seq_get(args, 0) + ), + "LOG10": lambda args: exp.Log( + this=exp.Literal.number(10), expression=seq_get(args, 0) + ), + "LOWER": build_lower, + "LPAD": lambda args: build_pad(args), + "LEFTPAD": lambda args: build_pad(args), + "LTRIM": lambda args: build_trim(args), + "MOD": build_mod, + "RIGHTPAD": lambda args: build_pad(args, is_left=False), + "RPAD": lambda args: build_pad(args, is_left=False), + "RTRIM": lambda args: build_trim(args, is_left=False), + "SCOPE_RESOLUTION": lambda args: exp.ScopeResolution( + expression=seq_get(args, 0) + ) + if len(args) != 2 + else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)), + "STRPOS": exp.StrPosition.from_arg_list, + "CHARINDEX": lambda args: build_locate_strposition(args), + "INSTR": exp.StrPosition.from_arg_list, + "LOCATE": lambda args: build_locate_strposition(args), + "TIME_TO_TIME_STR": lambda args: exp.Cast( + 
this=seq_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "TO_HEX": build_hex, + "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( + this=exp.Cast( + this=seq_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + start=exp.Literal.number(1), + length=exp.Literal.number(10), + ), + "UNNEST": lambda args: exp.Unnest(expressions=ensure_list(seq_get(args, 0))), + "UPPER": build_upper, + "UUID": lambda args, dialect: exp.Uuid( + is_string=dialect.UUID_IS_STRING_TYPE or None + ), + "VAR_MAP": build_var_map, + } + + NO_PAREN_FUNCTIONS = { + TokenType.CURRENT_DATE: exp.CurrentDate, + TokenType.CURRENT_DATETIME: exp.CurrentDate, + TokenType.CURRENT_TIME: exp.CurrentTime, + TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, + TokenType.CURRENT_USER: exp.CurrentUser, + TokenType.LOCALTIME: exp.Localtime, + TokenType.LOCALTIMESTAMP: exp.Localtimestamp, + TokenType.CURRENT_ROLE: exp.CurrentRole, + } + + STRUCT_TYPE_TOKENS = { + TokenType.FILE, + TokenType.NESTED, + TokenType.OBJECT, + TokenType.STRUCT, + TokenType.UNION, + } + + NESTED_TYPE_TOKENS = { + TokenType.ARRAY, + TokenType.LIST, + TokenType.LOWCARDINALITY, + TokenType.MAP, + TokenType.NULLABLE, + TokenType.RANGE, + *STRUCT_TYPE_TOKENS, + } + + ENUM_TYPE_TOKENS = { + TokenType.DYNAMIC, + TokenType.ENUM, + TokenType.ENUM8, + TokenType.ENUM16, + } + + AGGREGATE_TYPE_TOKENS = { + TokenType.AGGREGATEFUNCTION, + TokenType.SIMPLEAGGREGATEFUNCTION, + } + + TYPE_TOKENS = { + TokenType.BIT, + TokenType.BOOLEAN, + TokenType.TINYINT, + TokenType.UTINYINT, + TokenType.SMALLINT, + TokenType.USMALLINT, + TokenType.INT, + TokenType.UINT, + TokenType.BIGINT, + TokenType.UBIGINT, + TokenType.BIGNUM, + TokenType.INT128, + TokenType.UINT128, + TokenType.INT256, + TokenType.UINT256, + TokenType.MEDIUMINT, + TokenType.UMEDIUMINT, + TokenType.FIXEDSTRING, + TokenType.FLOAT, + TokenType.DOUBLE, + TokenType.UDOUBLE, + TokenType.CHAR, + TokenType.NCHAR, + TokenType.VARCHAR, + TokenType.NVARCHAR, + TokenType.BPCHAR, + TokenType.TEXT, + TokenType.MEDIUMTEXT, + TokenType.LONGTEXT, + TokenType.BLOB, + TokenType.MEDIUMBLOB, + TokenType.LONGBLOB, + TokenType.BINARY, + TokenType.VARBINARY, + TokenType.JSON, + TokenType.JSONB, + TokenType.INTERVAL, + TokenType.TINYBLOB, + TokenType.TINYTEXT, + TokenType.TIME, + TokenType.TIMETZ, + TokenType.TIME_NS, + TokenType.TIMESTAMP, + TokenType.TIMESTAMP_S, + TokenType.TIMESTAMP_MS, + TokenType.TIMESTAMP_NS, + TokenType.TIMESTAMPTZ, + TokenType.TIMESTAMPLTZ, + TokenType.TIMESTAMPNTZ, + TokenType.DATETIME, + TokenType.DATETIME2, + TokenType.DATETIME64, + TokenType.SMALLDATETIME, + TokenType.DATE, + TokenType.DATE32, + TokenType.INT4RANGE, + TokenType.INT4MULTIRANGE, + TokenType.INT8RANGE, + TokenType.INT8MULTIRANGE, + TokenType.NUMRANGE, + TokenType.NUMMULTIRANGE, + TokenType.TSRANGE, + TokenType.TSMULTIRANGE, + TokenType.TSTZRANGE, + TokenType.TSTZMULTIRANGE, + TokenType.DATERANGE, + TokenType.DATEMULTIRANGE, + TokenType.DECIMAL, + TokenType.DECIMAL32, + TokenType.DECIMAL64, + TokenType.DECIMAL128, + TokenType.DECIMAL256, + TokenType.DECFLOAT, + TokenType.UDECIMAL, + TokenType.BIGDECIMAL, + TokenType.UUID, + TokenType.GEOGRAPHY, + TokenType.GEOGRAPHYPOINT, + TokenType.GEOMETRY, + TokenType.POINT, + TokenType.RING, + TokenType.LINESTRING, + TokenType.MULTILINESTRING, + TokenType.POLYGON, + TokenType.MULTIPOLYGON, + TokenType.HLLSKETCH, + TokenType.HSTORE, + TokenType.PSEUDO_TYPE, + TokenType.SUPER, + TokenType.SERIAL, + TokenType.SMALLSERIAL, + TokenType.BIGSERIAL, + TokenType.XML, + TokenType.YEAR, + 
TokenType.USERDEFINED, + TokenType.MONEY, + TokenType.SMALLMONEY, + TokenType.ROWVERSION, + TokenType.IMAGE, + TokenType.VARIANT, + TokenType.VECTOR, + TokenType.VOID, + TokenType.OBJECT, + TokenType.OBJECT_IDENTIFIER, + TokenType.INET, + TokenType.IPADDRESS, + TokenType.IPPREFIX, + TokenType.IPV4, + TokenType.IPV6, + TokenType.UNKNOWN, + TokenType.NOTHING, + TokenType.NULL, + TokenType.NAME, + TokenType.TDIGEST, + TokenType.DYNAMIC, + *ENUM_TYPE_TOKENS, + *NESTED_TYPE_TOKENS, + *AGGREGATE_TYPE_TOKENS, + } + + SIGNED_TO_UNSIGNED_TYPE_TOKEN = { + TokenType.BIGINT: TokenType.UBIGINT, + TokenType.INT: TokenType.UINT, + TokenType.MEDIUMINT: TokenType.UMEDIUMINT, + TokenType.SMALLINT: TokenType.USMALLINT, + TokenType.TINYINT: TokenType.UTINYINT, + TokenType.DECIMAL: TokenType.UDECIMAL, + TokenType.DOUBLE: TokenType.UDOUBLE, + } + + SUBQUERY_PREDICATES = { + TokenType.ANY: exp.Any, + TokenType.ALL: exp.All, + TokenType.EXISTS: exp.Exists, + TokenType.SOME: exp.Any, + } + + RESERVED_TOKENS = { + *Tokenizer.SINGLE_TOKENS.values(), + TokenType.SELECT, + } - {TokenType.IDENTIFIER} + + DB_CREATABLES = { + TokenType.DATABASE, + TokenType.DICTIONARY, + TokenType.FILE_FORMAT, + TokenType.MODEL, + TokenType.NAMESPACE, + TokenType.SCHEMA, + TokenType.SEMANTIC_VIEW, + TokenType.SEQUENCE, + TokenType.SINK, + TokenType.SOURCE, + TokenType.STAGE, + TokenType.STORAGE_INTEGRATION, + TokenType.STREAMLIT, + TokenType.TABLE, + TokenType.TAG, + TokenType.VIEW, + TokenType.WAREHOUSE, + } + + CREATABLES = { + TokenType.COLUMN, + TokenType.CONSTRAINT, + TokenType.FOREIGN_KEY, + TokenType.FUNCTION, + TokenType.INDEX, + TokenType.PROCEDURE, + *DB_CREATABLES, + } + + ALTERABLES = { + TokenType.INDEX, + TokenType.TABLE, + TokenType.VIEW, + TokenType.SESSION, + } + + # Tokens that can represent identifiers + ID_VAR_TOKENS = { + TokenType.ALL, + TokenType.ANALYZE, + TokenType.ATTACH, + TokenType.VAR, + TokenType.ANTI, + TokenType.APPLY, + TokenType.ASC, + TokenType.ASOF, + TokenType.AUTO_INCREMENT, + TokenType.BEGIN, + TokenType.BPCHAR, + TokenType.CACHE, + TokenType.CASE, + TokenType.COLLATE, + TokenType.COMMAND, + TokenType.COMMENT, + TokenType.COMMIT, + TokenType.CONSTRAINT, + TokenType.COPY, + TokenType.CUBE, + TokenType.CURRENT_SCHEMA, + TokenType.DEFAULT, + TokenType.DELETE, + TokenType.DESC, + TokenType.DESCRIBE, + TokenType.DETACH, + TokenType.DICTIONARY, + TokenType.DIV, + TokenType.END, + TokenType.EXECUTE, + TokenType.EXPORT, + TokenType.ESCAPE, + TokenType.FALSE, + TokenType.FIRST, + TokenType.FILTER, + TokenType.FINAL, + TokenType.FORMAT, + TokenType.FULL, + TokenType.GET, + TokenType.IDENTIFIER, + TokenType.IS, + TokenType.ISNULL, + TokenType.INTERVAL, + TokenType.KEEP, + TokenType.KILL, + TokenType.LEFT, + TokenType.LIMIT, + TokenType.LOAD, + TokenType.LOCK, + TokenType.MATCH, + TokenType.MERGE, + TokenType.NATURAL, + TokenType.NEXT, + TokenType.OFFSET, + TokenType.OPERATOR, + TokenType.ORDINALITY, + TokenType.OVER, + TokenType.OVERLAPS, + TokenType.OVERWRITE, + TokenType.PARTITION, + TokenType.PERCENT, + TokenType.PIVOT, + TokenType.PRAGMA, + TokenType.PUT, + TokenType.RANGE, + TokenType.RECURSIVE, + TokenType.REFERENCES, + TokenType.REFRESH, + TokenType.RENAME, + TokenType.REPLACE, + TokenType.RIGHT, + TokenType.ROLLUP, + TokenType.ROW, + TokenType.ROWS, + TokenType.SEMI, + TokenType.SET, + TokenType.SETTINGS, + TokenType.SHOW, + TokenType.TEMPORARY, + TokenType.TOP, + TokenType.TRUE, + TokenType.TRUNCATE, + TokenType.UNIQUE, + TokenType.UNNEST, + TokenType.UNPIVOT, + TokenType.UPDATE, + TokenType.USE, + 
TokenType.VOLATILE, + TokenType.WINDOW, + *ALTERABLES, + *CREATABLES, + *SUBQUERY_PREDICATES, + *TYPE_TOKENS, + *NO_PAREN_FUNCTIONS, + } + ID_VAR_TOKENS.remove(TokenType.UNION) + + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { + TokenType.ANTI, + TokenType.ASOF, + TokenType.FULL, + TokenType.LEFT, + TokenType.LOCK, + TokenType.NATURAL, + TokenType.RIGHT, + TokenType.SEMI, + TokenType.WINDOW, + } + + ALIAS_TOKENS = ID_VAR_TOKENS + + COLON_PLACEHOLDER_TOKENS = ID_VAR_TOKENS + + ARRAY_CONSTRUCTORS = { + "ARRAY": exp.Array, + "LIST": exp.List, + } + + COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS} + + UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} + + TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"} + + FUNC_TOKENS = { + TokenType.COLLATE, + TokenType.COMMAND, + TokenType.CURRENT_DATE, + TokenType.CURRENT_DATETIME, + TokenType.CURRENT_SCHEMA, + TokenType.CURRENT_TIMESTAMP, + TokenType.CURRENT_TIME, + TokenType.CURRENT_USER, + TokenType.CURRENT_CATALOG, + TokenType.FILTER, + TokenType.FIRST, + TokenType.FORMAT, + TokenType.GET, + TokenType.GLOB, + TokenType.IDENTIFIER, + TokenType.INDEX, + TokenType.ISNULL, + TokenType.ILIKE, + TokenType.INSERT, + TokenType.LIKE, + TokenType.LOCALTIME, + TokenType.LOCALTIMESTAMP, + TokenType.MERGE, + TokenType.NEXT, + TokenType.OFFSET, + TokenType.PRIMARY_KEY, + TokenType.RANGE, + TokenType.REPLACE, + TokenType.RLIKE, + TokenType.ROW, + TokenType.SESSION_USER, + TokenType.UNNEST, + TokenType.VAR, + TokenType.LEFT, + TokenType.RIGHT, + TokenType.SEQUENCE, + TokenType.DATE, + TokenType.DATETIME, + TokenType.TABLE, + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + TokenType.TRUNCATE, + TokenType.UTC_DATE, + TokenType.UTC_TIME, + TokenType.UTC_TIMESTAMP, + TokenType.WINDOW, + TokenType.XOR, + *TYPE_TOKENS, + *SUBQUERY_PREDICATES, + } + + CONJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = { + TokenType.AND: exp.And, + } + + ASSIGNMENT: t.Dict[TokenType, t.Type[exp.Expression]] = { + TokenType.COLON_EQ: exp.PropertyEQ, + } + + DISJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = { + TokenType.OR: exp.Or, + } + + EQUALITY = { + TokenType.EQ: exp.EQ, + TokenType.NEQ: exp.NEQ, + TokenType.NULLSAFE_EQ: exp.NullSafeEQ, + } + + COMPARISON = { + TokenType.GT: exp.GT, + TokenType.GTE: exp.GTE, + TokenType.LT: exp.LT, + TokenType.LTE: exp.LTE, + } + + BITWISE = { + TokenType.AMP: exp.BitwiseAnd, + TokenType.CARET: exp.BitwiseXor, + TokenType.PIPE: exp.BitwiseOr, + } + + TERM = { + TokenType.DASH: exp.Sub, + TokenType.PLUS: exp.Add, + TokenType.MOD: exp.Mod, + TokenType.COLLATE: exp.Collate, + } + + FACTOR = { + TokenType.DIV: exp.IntDiv, + TokenType.LR_ARROW: exp.Distance, + TokenType.SLASH: exp.Div, + TokenType.STAR: exp.Mul, + } + + EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {} + + TIMES = { + TokenType.TIME, + TokenType.TIMETZ, + } + + TIMESTAMPS = { + TokenType.TIMESTAMP, + TokenType.TIMESTAMPNTZ, + TokenType.TIMESTAMPTZ, + TokenType.TIMESTAMPLTZ, + *TIMES, + } + + SET_OPERATIONS = { + TokenType.UNION, + TokenType.INTERSECT, + TokenType.EXCEPT, + } + + JOIN_METHODS = { + TokenType.ASOF, + TokenType.NATURAL, + TokenType.POSITIONAL, + } + + JOIN_SIDES = { + TokenType.LEFT, + TokenType.RIGHT, + TokenType.FULL, + } + + JOIN_KINDS = { + TokenType.ANTI, + TokenType.CROSS, + TokenType.INNER, + TokenType.OUTER, + TokenType.SEMI, + TokenType.STRAIGHT_JOIN, + } + + JOIN_HINTS: t.Set[str] = set() + + LAMBDAS = { + TokenType.ARROW: lambda self, expressions: self.expression( + exp.Lambda, + this=self._replace_lambda( + self._parse_disjunction(), 
+ expressions, + ), + expressions=expressions, + ), + TokenType.FARROW: lambda self, expressions: self.expression( + exp.Kwarg, + this=exp.var(expressions[0].name), + expression=self._parse_disjunction(), + ), + } + + COLUMN_OPERATORS = { + TokenType.DOT: None, + TokenType.DOTCOLON: lambda self, this, to: self.expression( + exp.JSONCast, + this=this, + to=to, + ), + TokenType.DCOLON: lambda self, this, to: self.build_cast( + strict=self.STRICT_CAST, this=this, to=to + ), + TokenType.ARROW: lambda self, this, path: self.expression( + exp.JSONExtract, + this=this, + expression=self.dialect.to_json_path(path), + only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, + ), + TokenType.DARROW: lambda self, this, path: self.expression( + exp.JSONExtractScalar, + this=this, + expression=self.dialect.to_json_path(path), + only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, + scalar_only=self.dialect.JSON_EXTRACT_SCALAR_SCALAR_ONLY, + ), + TokenType.HASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtract, + this=this, + expression=path, + ), + TokenType.DHASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtractScalar, + this=this, + expression=path, + ), + TokenType.PLACEHOLDER: lambda self, this, key: self.expression( + exp.JSONBContains, + this=this, + expression=key, + ), + } + + CAST_COLUMN_OPERATORS = { + TokenType.DOTCOLON, + TokenType.DCOLON, + } + + EXPRESSION_PARSERS = { + exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), + exp.Column: lambda self: self._parse_column(), + exp.ColumnDef: lambda self: self._parse_column_def(self._parse_column()), + exp.Condition: lambda self: self._parse_disjunction(), + exp.DataType: lambda self: self._parse_types( + allow_identifiers=False, schema=True + ), + exp.Expression: lambda self: self._parse_expression(), + exp.From: lambda self: self._parse_from(joins=True), + exp.GrantPrincipal: lambda self: self._parse_grant_principal(), + exp.GrantPrivilege: lambda self: self._parse_grant_privilege(), + exp.Group: lambda self: self._parse_group(), + exp.Having: lambda self: self._parse_having(), + exp.Hint: lambda self: self._parse_hint_body(), + exp.Identifier: lambda self: self._parse_id_var(), + exp.Join: lambda self: self._parse_join(), + exp.Lambda: lambda self: self._parse_lambda(), + exp.Lateral: lambda self: self._parse_lateral(), + exp.Limit: lambda self: self._parse_limit(), + exp.Offset: lambda self: self._parse_offset(), + exp.Order: lambda self: self._parse_order(), + exp.Ordered: lambda self: self._parse_ordered(), + exp.Properties: lambda self: self._parse_properties(), + exp.PartitionedByProperty: lambda self: self._parse_partitioned_by(), + exp.Qualify: lambda self: self._parse_qualify(), + exp.Returning: lambda self: self._parse_returning(), + exp.Select: lambda self: self._parse_select(), + exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY), + exp.Table: lambda self: self._parse_table_parts(), + exp.TableAlias: lambda self: self._parse_table_alias(), + exp.Tuple: lambda self: self._parse_value(values=False), + exp.Whens: lambda self: self._parse_when_matched(), + exp.Where: lambda self: self._parse_where(), + exp.Window: lambda self: self._parse_named_window(), + exp.With: lambda self: self._parse_with(), + "JOIN_TYPE": lambda self: self._parse_join_parts(), + } + + STATEMENT_PARSERS = { + TokenType.ALTER: lambda self: self._parse_alter(), + TokenType.ANALYZE: lambda self: self._parse_analyze(), + TokenType.BEGIN: lambda self: self._parse_transaction(), + 
TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.COMMENT: lambda self: self._parse_comment(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.COPY: lambda self: self._parse_copy(), + TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.DESC: lambda self: self._parse_describe(), + TokenType.DESCRIBE: lambda self: self._parse_describe(), + TokenType.DROP: lambda self: self._parse_drop(), + TokenType.GRANT: lambda self: self._parse_grant(), + TokenType.REVOKE: lambda self: self._parse_revoke(), + TokenType.INSERT: lambda self: self._parse_insert(), + TokenType.KILL: lambda self: self._parse_kill(), + TokenType.LOAD: lambda self: self._parse_load(), + TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), + TokenType.PRAGMA: lambda self: self.expression( + exp.Pragma, this=self._parse_expression() + ), + TokenType.REFRESH: lambda self: self._parse_refresh(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), + TokenType.SET: lambda self: self._parse_set(), + TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), + TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.UNPIVOT: lambda self: self._parse_simplified_pivot(is_unpivot=True), + TokenType.UPDATE: lambda self: self._parse_update(), + TokenType.USE: lambda self: self._parse_use(), + TokenType.SEMICOLON: lambda self: exp.Semicolon(), + } + + UNARY_PARSERS = { + TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op + TokenType.NOT: lambda self: self.expression( + exp.Not, this=self._parse_equality() + ), + TokenType.TILDA: lambda self: self.expression( + exp.BitwiseNot, this=self._parse_unary() + ), + TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()), + TokenType.PIPE_SLASH: lambda self: self.expression( + exp.Sqrt, this=self._parse_unary() + ), + TokenType.DPIPE_SLASH: lambda self: self.expression( + exp.Cbrt, this=self._parse_unary() + ), + } + + STRING_PARSERS = { + TokenType.HEREDOC_STRING: lambda self, token: self.expression( + exp.RawString, token=token + ), + TokenType.NATIONAL_STRING: lambda self, token: self.expression( + exp.National, token=token + ), + TokenType.RAW_STRING: lambda self, token: self.expression( + exp.RawString, token=token + ), + TokenType.STRING: lambda self, token: self.expression( + exp.Literal, token=token, is_string=True + ), + TokenType.UNICODE_STRING: lambda self, token: self.expression( + exp.UnicodeString, + token=token, + escape=self._match_text_seq("UESCAPE") and self._parse_string(), + ), + } + + NUMERIC_PARSERS = { + TokenType.BIT_STRING: lambda self, token: self.expression( + exp.BitString, token=token + ), + TokenType.BYTE_STRING: lambda self, token: self.expression( + exp.ByteString, + token=token, + is_bytes=self.dialect.BYTE_STRING_IS_BYTES_TYPE or None, + ), + TokenType.HEX_STRING: lambda self, token: self.expression( + exp.HexString, + token=token, + is_integer=self.dialect.HEX_STRING_IS_INTEGER_TYPE or None, + ), + TokenType.NUMBER: lambda self, token: self.expression( + exp.Literal, token=token, is_string=False + ), + } + + PRIMARY_PARSERS = { + **STRING_PARSERS, + **NUMERIC_PARSERS, + TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.NULL: lambda self, _: self.expression(exp.Null), + TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), + TokenType.FALSE: lambda self, _: 
self.expression(exp.Boolean, this=False), + TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), + TokenType.STAR: lambda self, _: self._parse_star_ops(), + } + + PLACEHOLDER_PARSERS = { + TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), + TokenType.PARAMETER: lambda self: self._parse_parameter(), + TokenType.COLON: lambda self: ( + self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set(self.COLON_PLACEHOLDER_TOKENS) + else None + ), + } + + RANGE_PARSERS = { + TokenType.AT_GT: binary_range_parser(exp.ArrayContainsAll), + TokenType.BETWEEN: lambda self, this: self._parse_between(this), + TokenType.GLOB: binary_range_parser(exp.Glob), + TokenType.ILIKE: binary_range_parser(exp.ILike), + TokenType.IN: lambda self, this: self._parse_in(this), + TokenType.IRLIKE: binary_range_parser(exp.RegexpILike), + TokenType.IS: lambda self, this: self._parse_is(this), + TokenType.LIKE: binary_range_parser(exp.Like), + TokenType.LT_AT: binary_range_parser(exp.ArrayContainsAll, reverse_args=True), + TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), + TokenType.RLIKE: binary_range_parser(exp.RegexpLike), + TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo), + TokenType.FOR: lambda self, this: self._parse_comprehension(this), + TokenType.QMARK_AMP: binary_range_parser(exp.JSONBContainsAllTopKeys), + TokenType.QMARK_PIPE: binary_range_parser(exp.JSONBContainsAnyTopKeys), + TokenType.HASH_DASH: binary_range_parser(exp.JSONBDeleteAtPath), + TokenType.ADJACENT: binary_range_parser(exp.Adjacent), + TokenType.OPERATOR: lambda self, this: self._parse_operator(this), + TokenType.AMP_LT: binary_range_parser(exp.ExtendsLeft), + TokenType.AMP_GT: binary_range_parser(exp.ExtendsRight), + } + + PIPE_SYNTAX_TRANSFORM_PARSERS = { + "AGGREGATE": lambda self, query: self._parse_pipe_syntax_aggregate(query), + "AS": lambda self, query: self._build_pipe_cte( + query, [exp.Star()], self._parse_table_alias() + ), + "EXTEND": lambda self, query: self._parse_pipe_syntax_extend(query), + "LIMIT": lambda self, query: self._parse_pipe_syntax_limit(query), + "ORDER BY": lambda self, query: query.order_by( + self._parse_order(), append=False, copy=False + ), + "PIVOT": lambda self, query: self._parse_pipe_syntax_pivot(query), + "SELECT": lambda self, query: self._parse_pipe_syntax_select(query), + "TABLESAMPLE": lambda self, query: self._parse_pipe_syntax_tablesample(query), + "UNPIVOT": lambda self, query: self._parse_pipe_syntax_pivot(query), + "WHERE": lambda self, query: query.where(self._parse_where(), copy=False), + } + + PROPERTY_PARSERS: t.Dict[str, t.Callable] = { + "ALLOWED_VALUES": lambda self: self.expression( + exp.AllowedValuesProperty, expressions=self._parse_csv(self._parse_primary) + ), + "ALGORITHM": lambda self: self._parse_property_assignment( + exp.AlgorithmProperty + ), + "AUTO": lambda self: self._parse_auto_property(), + "AUTO_INCREMENT": lambda self: self._parse_property_assignment( + exp.AutoIncrementProperty + ), + "BACKUP": lambda self: self.expression( + exp.BackupProperty, this=self._parse_var(any_token=True) + ), + "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), + "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), + "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs), + "CHECKSUM": lambda self: self._parse_checksum(), + "CLUSTER BY": lambda self: self._parse_cluster(), + "CLUSTERED": lambda self: self._parse_clustered_by(), + "COLLATE": lambda self, **kwargs: 
self._parse_property_assignment( + exp.CollateProperty, **kwargs + ), + "COMMENT": lambda self: self._parse_property_assignment( + exp.SchemaCommentProperty + ), + "CONTAINS": lambda self: self._parse_contains_property(), + "COPY": lambda self: self._parse_copy_property(), + "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), + "DATA_DELETION": lambda self: self._parse_data_deletion_property(), + "DEFINER": lambda self: self._parse_definer(), + "DETERMINISTIC": lambda self: self.expression( + exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + "DISTRIBUTED": lambda self: self._parse_distributed_property(), + "DUPLICATE": lambda self: self._parse_composite_key_property( + exp.DuplicateKeyProperty + ), + "DYNAMIC": lambda self: self.expression(exp.DynamicProperty), + "DISTKEY": lambda self: self._parse_distkey(), + "DISTSTYLE": lambda self: self._parse_property_assignment( + exp.DistStyleProperty + ), + "EMPTY": lambda self: self.expression(exp.EmptyProperty), + "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), + "ENVIRONMENT": lambda self: self.expression( + exp.EnviromentProperty, + expressions=self._parse_wrapped_csv(self._parse_assignment), + ), + "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), + "EXTERNAL": lambda self: self.expression(exp.ExternalProperty), + "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs), + "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "FREESPACE": lambda self: self._parse_freespace(), + "GLOBAL": lambda self: self.expression(exp.GlobalProperty), + "HEAP": lambda self: self.expression(exp.HeapProperty), + "ICEBERG": lambda self: self.expression(exp.IcebergProperty), + "IMMUTABLE": lambda self: self.expression( + exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + "INHERITS": lambda self: self.expression( + exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table) + ), + "INPUT": lambda self: self.expression( + exp.InputModelProperty, this=self._parse_schema() + ), + "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), + "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), + "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"), + "LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"), + "LIKE": lambda self: self._parse_create_like(), + "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), + "LOCK": lambda self: self._parse_locking(), + "LOCKING": lambda self: self._parse_locking(), + "LOG": lambda self, **kwargs: self._parse_log(**kwargs), + "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), + "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs), + "MODIFIES": lambda self: self._parse_modifies_property(), + "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), + "NO": lambda self: self._parse_no_property(), + "ON": lambda self: self._parse_on_property(), + "ORDER BY": lambda self: self._parse_order(skip_order_token=True), + "OUTPUT": lambda self: self.expression( + exp.OutputModelProperty, this=self._parse_schema() + ), + "PARTITION": lambda self: self._parse_partitioned_of(), + "PARTITION BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), + "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), + 
"RANGE": lambda self: self._parse_dict_range(this="RANGE"), + "READS": lambda self: self._parse_reads_property(), + "REMOTE": lambda self: self._parse_remote_with_connection(), + "RETURNS": lambda self: self._parse_returns(), + "STRICT": lambda self: self.expression(exp.StrictProperty), + "STREAMING": lambda self: self.expression(exp.StreamingTableProperty), + "ROW": lambda self: self._parse_row(), + "ROW_FORMAT": lambda self: self._parse_property_assignment( + exp.RowFormatProperty + ), + "SAMPLE": lambda self: self.expression( + exp.SampleProperty, + this=self._match_text_seq("BY") and self._parse_bitwise(), + ), + "SECURE": lambda self: self.expression(exp.SecureProperty), + "SECURITY": lambda self: self._parse_security(), + "SET": lambda self: self.expression(exp.SetProperty, multi=False), + "SETTINGS": lambda self: self._parse_settings_property(), + "SHARING": lambda self: self._parse_property_assignment(exp.SharingProperty), + "SORTKEY": lambda self: self._parse_sortkey(), + "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"), + "STABLE": lambda self: self.expression( + exp.StabilityProperty, this=exp.Literal.string("STABLE") + ), + "STORED": lambda self: self._parse_stored(), + "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(), + "TBLPROPERTIES": lambda self: self._parse_wrapped_properties(), + "TEMP": lambda self: self.expression(exp.TemporaryProperty), + "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), + "TO": lambda self: self._parse_to_table(), + "TRANSIENT": lambda self: self.expression(exp.TransientProperty), + "TRANSFORM": lambda self: self.expression( + exp.TransformModelProperty, + expressions=self._parse_wrapped_csv(self._parse_expression), + ), + "TTL": lambda self: self._parse_ttl(), + "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "UNLOGGED": lambda self: self.expression(exp.UnloggedProperty), + "VOLATILE": lambda self: self._parse_volatile_property(), + "WITH": lambda self: self._parse_with_property(), + } + + CONSTRAINT_PARSERS = { + "AUTOINCREMENT": lambda self: self._parse_auto_increment(), + "AUTO_INCREMENT": lambda self: self._parse_auto_increment(), + "CASESPECIFIC": lambda self: self.expression( + exp.CaseSpecificColumnConstraint, not_=False + ), + "CHARACTER SET": lambda self: self.expression( + exp.CharacterSetColumnConstraint, this=self._parse_var_or_string() + ), + "CHECK": lambda self: self.expression( + exp.CheckColumnConstraint, + this=self._parse_wrapped(self._parse_assignment), + enforced=self._match_text_seq("ENFORCED"), + ), + "COLLATE": lambda self: self.expression( + exp.CollateColumnConstraint, + this=self._parse_identifier() or self._parse_column(), + ), + "COMMENT": lambda self: self.expression( + exp.CommentColumnConstraint, this=self._parse_string() + ), + "COMPRESS": lambda self: self._parse_compress(), + "CLUSTERED": lambda self: self.expression( + exp.ClusteredColumnConstraint, + this=self._parse_wrapped_csv(self._parse_ordered), + ), + "NONCLUSTERED": lambda self: self.expression( + exp.NonClusteredColumnConstraint, + this=self._parse_wrapped_csv(self._parse_ordered), + ), + "DEFAULT": lambda self: self.expression( + exp.DefaultColumnConstraint, this=self._parse_bitwise() + ), + "ENCODE": lambda self: self.expression( + exp.EncodeColumnConstraint, this=self._parse_var() + ), + "EPHEMERAL": lambda self: self.expression( + exp.EphemeralColumnConstraint, this=self._parse_bitwise() + ), + "EXCLUDE": lambda self: self.expression( + 
exp.ExcludeColumnConstraint, this=self._parse_index_params() + ), + "FOREIGN KEY": lambda self: self._parse_foreign_key(), + "FORMAT": lambda self: self.expression( + exp.DateFormatColumnConstraint, this=self._parse_var_or_string() + ), + "GENERATED": lambda self: self._parse_generated_as_identity(), + "IDENTITY": lambda self: self._parse_auto_increment(), + "INLINE": lambda self: self._parse_inline(), + "LIKE": lambda self: self._parse_create_like(), + "NOT": lambda self: self._parse_not_constraint(), + "NULL": lambda self: self.expression( + exp.NotNullColumnConstraint, allow_null=True + ), + "ON": lambda self: ( + self._match(TokenType.UPDATE) + and self.expression( + exp.OnUpdateColumnConstraint, this=self._parse_function() + ) + ) + or self.expression(exp.OnProperty, this=self._parse_id_var()), + "PATH": lambda self: self.expression( + exp.PathColumnConstraint, this=self._parse_string() + ), + "PERIOD": lambda self: self._parse_period_for_system_time(), + "PRIMARY KEY": lambda self: self._parse_primary_key(), + "REFERENCES": lambda self: self._parse_references(match=False), + "TITLE": lambda self: self.expression( + exp.TitleColumnConstraint, this=self._parse_var_or_string() + ), + "TTL": lambda self: self.expression( + exp.MergeTreeTTL, expressions=[self._parse_bitwise()] + ), + "UNIQUE": lambda self: self._parse_unique(), + "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), + "WITH": lambda self: self.expression( + exp.Properties, expressions=self._parse_wrapped_properties() + ), + "BUCKET": lambda self: self._parse_partitioned_by_bucket_or_truncate(), + "TRUNCATE": lambda self: self._parse_partitioned_by_bucket_or_truncate(), + } + + def _parse_partitioned_by_bucket_or_truncate(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.L_PAREN, advance=False): + # Partitioning by bucket or truncate follows the syntax: + # PARTITION BY (BUCKET(..) 
| TRUNCATE(..))
+            # If we don't have parentheses after each keyword, we should instead parse this as an identifier
+            self._retreat(self._index - 1)
+            return None
+
+        klass = (
+            exp.PartitionedByBucket
+            if self._prev.text.upper() == "BUCKET"
+            else exp.PartitionByTruncate
+        )
+
+        args = self._parse_wrapped_csv(
+            lambda: self._parse_primary() or self._parse_column()
+        )
+        this, expression = seq_get(args, 0), seq_get(args, 1)
+
+        if isinstance(this, exp.Literal):
+            # Check for Iceberg partition transforms (bucket / truncate) and ensure their arguments are in the right order
+            #  - For Hive, it's `bucket(<num buckets>, <col name>)` or `truncate(<num_chars>, <col_name>)`
+            #  - For Trino, it's reversed - `bucket(<col name>, <num buckets>)` or `truncate(<col_name>, <num_chars>)`
+            # Both variants are canonicalized in the latter i.e `bucket(<col name>, <num buckets>)`
+            #
+            # Hive ref: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html#querying-iceberg-partitioning
+            # Trino ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
+            this, expression = expression, this
+
+        return self.expression(klass, this=this, expression=expression)
+
+    ALTER_PARSERS = {
+        "ADD": lambda self: self._parse_alter_table_add(),
+        "AS": lambda self: self._parse_select(),
+        "ALTER": lambda self: self._parse_alter_table_alter(),
+        "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
+        "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
+        "DROP": lambda self: self._parse_alter_table_drop(),
+        "RENAME": lambda self: self._parse_alter_table_rename(),
+        "SET": lambda self: self._parse_alter_table_set(),
+        "SWAP": lambda self: self.expression(
+            exp.SwapTable,
+            this=self._match(TokenType.WITH) and self._parse_table(schema=True),
+        ),
+    }
+
+    ALTER_ALTER_PARSERS = {
+        "DISTKEY": lambda self: self._parse_alter_diststyle(),
+        "DISTSTYLE": lambda self: self._parse_alter_diststyle(),
+        "SORTKEY": lambda self: self._parse_alter_sortkey(),
+        "COMPOUND": lambda self: self._parse_alter_sortkey(compound=True),
+    }
+
+    SCHEMA_UNNAMED_CONSTRAINTS = {
+        "CHECK",
+        "EXCLUDE",
+        "FOREIGN KEY",
+        "LIKE",
+        "PERIOD",
+        "PRIMARY KEY",
+        "UNIQUE",
+        "BUCKET",
+        "TRUNCATE",
+    }
+
+    NO_PAREN_FUNCTION_PARSERS = {
+        "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
+        "CASE": lambda self: self._parse_case(),
+        "CONNECT_BY_ROOT": lambda self: self.expression(
+            exp.ConnectByRoot, this=self._parse_column()
+        ),
+        "IF": lambda self: self._parse_if(),
+    }
+
+    INVALID_FUNC_NAME_TOKENS = {
+        TokenType.IDENTIFIER,
+        TokenType.STRING,
+    }
+
+    FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
+
+    KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice)
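+
+    # Functions whose arguments need bespoke parsing instead of a generic
+    # comma-separated list, e.g. `CAST(x AS INT)` or `TRIM(BOTH 'x' FROM y)`,
+    # where keywords appear inside the argument list.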
+    FUNCTION_PARSERS = {
+        **{
+            name: lambda self: self._parse_max_min_by(exp.ArgMax)
+            for name in exp.ArgMax.sql_names()
+        },
+        **{
+            name: lambda self: self._parse_max_min_by(exp.ArgMin)
+            for name in exp.ArgMin.sql_names()
+        },
+        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+        "CEIL": lambda self: self._parse_ceil_floor(exp.Ceil),
+        "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
+        "DECODE": lambda self: self._parse_decode(),
+        "EXTRACT": lambda self: self._parse_extract(),
+        "FLOOR": lambda self: self._parse_ceil_floor(exp.Floor),
+        "GAP_FILL": lambda self: self._parse_gap_fill(),
+        "INITCAP": lambda self: self._parse_initcap(),
+        "JSON_OBJECT": lambda self: self._parse_json_object(),
+        "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True),
+        "JSON_TABLE": lambda self: self._parse_json_table(),
+        "MATCH": lambda self: self._parse_match_against(),
+        "NORMALIZE": lambda self: self._parse_normalize(),
+        "OPENJSON": lambda self: self._parse_open_json(),
+        "OVERLAY": lambda self: self._parse_overlay(),
+        "POSITION": lambda self: self._parse_position(),
+        "SAFE_CAST": lambda self: self._parse_cast(False, safe=True),
+        "STRING_AGG": lambda self: self._parse_string_agg(),
+        "SUBSTRING": lambda self: self._parse_substring(),
+        "TRIM": lambda self: self._parse_trim(),
+        "TRY_CAST": lambda self: self._parse_cast(False, safe=True),
+        "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True),
+        "XMLELEMENT": lambda self: self.expression(
+            exp.XMLElement,
+            this=self._match_text_seq("NAME") and self._parse_id_var(),
+            expressions=self._match(TokenType.COMMA)
+            and self._parse_csv(self._parse_expression),
+        ),
+        "XMLTABLE": lambda self: self._parse_xml_table(),
+    }
+
+    QUERY_MODIFIER_PARSERS = {
+        TokenType.MATCH_RECOGNIZE: lambda self: (
+            "match",
+            self._parse_match_recognize(),
+        ),
+        TokenType.PREWHERE: lambda self: ("prewhere", self._parse_prewhere()),
+        TokenType.WHERE: lambda self: ("where", self._parse_where()),
+        TokenType.GROUP_BY: lambda self: ("group", self._parse_group()),
+        TokenType.HAVING: lambda self: ("having", self._parse_having()),
+        TokenType.QUALIFY: lambda self: ("qualify", self._parse_qualify()),
+        TokenType.WINDOW: lambda self: ("windows", self._parse_window_clause()),
+        TokenType.ORDER_BY: lambda self: ("order", self._parse_order()),
+        TokenType.LIMIT: lambda self: ("limit", self._parse_limit()),
+        TokenType.FETCH: lambda self: ("limit", self._parse_limit()),
+        TokenType.OFFSET: lambda self: ("offset", self._parse_offset()),
+        TokenType.FOR: lambda self: ("locks", self._parse_locks()),
+        TokenType.LOCK: lambda self: ("locks", self._parse_locks()),
+        TokenType.TABLE_SAMPLE: lambda self: (
+            "sample",
+            self._parse_table_sample(as_modifier=True),
+        ),
+        TokenType.USING: lambda self: (
+            "sample",
+            self._parse_table_sample(as_modifier=True),
+        ),
+        TokenType.CLUSTER_BY: lambda self: (
+            "cluster",
+            self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
+        ),
+        TokenType.DISTRIBUTE_BY: lambda self: (
+            "distribute",
+            self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
+        ),
+        TokenType.SORT_BY: lambda self: (
+            "sort",
+            self._parse_sort(exp.Sort, TokenType.SORT_BY),
+        ),
+        TokenType.CONNECT_BY: lambda self: (
+            "connect",
+            self._parse_connect(skip_start_token=True),
+        ),
+        TokenType.START_WITH: lambda self: ("connect", self._parse_connect()),
+    }
+    QUERY_MODIFIER_TOKENS = set(QUERY_MODIFIER_PARSERS)
+
+    SET_PARSERS = {
+        "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+        "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+        "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+        "TRANSACTION": lambda self: self._parse_set_transaction(),
+    }
+
+    SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+
+    TYPE_LITERAL_PARSERS = {
+        exp.DataType.Type.JSON: lambda self, this, _: self.expression(
+            exp.ParseJSON, this=this
+        ),
+    }
+
+    TYPE_CONVERTERS: t.Dict[
+        exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]
+    ] = {}
+
+    DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
+
+    PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
+
+    TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+    TRANSACTION_CHARACTERISTICS: OPTIONS_TYPE = {
+        "ISOLATION": (
+            ("LEVEL", "REPEATABLE", "READ"),
+            ("LEVEL", "READ", "COMMITTED"),
+            ("LEVEL", "READ", "UNCOMMITTED"),
+            ("LEVEL", "SERIALIZABLE"),
+        ),
+
"READ": ("WRITE", "ONLY"), + } + + CONFLICT_ACTIONS: OPTIONS_TYPE = dict.fromkeys( + ("ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK", "UPDATE"), tuple() + ) + CONFLICT_ACTIONS["DO"] = ("NOTHING", "UPDATE") + + CREATE_SEQUENCE: OPTIONS_TYPE = { + "SCALE": ("EXTEND", "NOEXTEND"), + "SHARD": ("EXTEND", "NOEXTEND"), + "NO": ("CYCLE", "CACHE", "MAXVALUE", "MINVALUE"), + **dict.fromkeys( + ( + "SESSION", + "GLOBAL", + "KEEP", + "NOKEEP", + "ORDER", + "NOORDER", + "NOCACHE", + "CYCLE", + "NOCYCLE", + "NOMINVALUE", + "NOMAXVALUE", + "NOSCALE", + "NOSHARD", + ), + tuple(), + ), + } + + ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")} + + USABLES: OPTIONS_TYPE = dict.fromkeys( + ("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA", "CATALOG"), tuple() + ) + + CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",)) + + SCHEMA_BINDING_OPTIONS: OPTIONS_TYPE = { + "TYPE": ("EVOLUTION",), + **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()), + } + + PROCEDURE_OPTIONS: OPTIONS_TYPE = {} + + EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys( + ("CALLER", "SELF", "OWNER"), tuple() + ) + + KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = { + "NOT": ("ENFORCED",), + "MATCH": ( + "FULL", + "PARTIAL", + "SIMPLE", + ), + "INITIALLY": ("DEFERRED", "IMMEDIATE"), + "USING": ( + "BTREE", + "HASH", + ), + **dict.fromkeys(("DEFERRABLE", "NORELY", "RELY"), tuple()), + } + + WINDOW_EXCLUDE_OPTIONS: OPTIONS_TYPE = { + "NO": ("OTHERS",), + "CURRENT": ("ROW",), + **dict.fromkeys(("GROUP", "TIES"), tuple()), + } + + INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} + + CLONE_KEYWORDS = {"CLONE", "COPY"} + HISTORICAL_DATA_PREFIX = {"AT", "BEFORE", "END"} + HISTORICAL_DATA_KIND = {"OFFSET", "STATEMENT", "STREAM", "TIMESTAMP", "VERSION"} + + OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"} + + OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} + + TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} + + VIEW_ATTRIBUTES = {"ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"} + + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.RANGE, TokenType.ROWS} + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} + WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} + + JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS} + + FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} + + ADD_CONSTRAINT_TOKENS = { + TokenType.CONSTRAINT, + TokenType.FOREIGN_KEY, + TokenType.INDEX, + TokenType.KEY, + TokenType.PRIMARY_KEY, + TokenType.UNIQUE, + } + + DISTINCT_TOKENS = {TokenType.DISTINCT} + + UNNEST_OFFSET_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - SET_OPERATIONS + + SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT} + + COPY_INTO_VARLEN_OPTIONS = { + "FILE_FORMAT", + "COPY_OPTIONS", + "FORMAT_OPTIONS", + "CREDENTIAL", + } + + IS_JSON_PREDICATE_KIND = {"VALUE", "SCALAR", "ARRAY", "OBJECT"} + + ODBC_DATETIME_LITERALS: t.Dict[str, t.Type[exp.Expression]] = {} + + ON_CONDITION_TOKENS = {"ERROR", "NULL", "TRUE", "FALSE", "EMPTY"} + + PRIVILEGE_FOLLOW_TOKENS = {TokenType.ON, TokenType.COMMA, TokenType.L_PAREN} + + # The style options for the DESCRIBE statement + DESCRIBE_STYLES = {"ANALYZE", "EXTENDED", "FORMATTED", "HISTORY"} + + SET_ASSIGNMENT_DELIMITERS = {"=", ":=", "TO"} + + # The style options for the ANALYZE statement + ANALYZE_STYLES = { + "BUFFER_USAGE_LIMIT", + "FULL", + "LOCAL", + "NO_WRITE_TO_BINLOG", + "SAMPLE", + "SKIP_LOCKED", + "VERBOSE", + } + + ANALYZE_EXPRESSION_PARSERS 
= { + "ALL": lambda self: self._parse_analyze_columns(), + "COMPUTE": lambda self: self._parse_analyze_statistics(), + "DELETE": lambda self: self._parse_analyze_delete(), + "DROP": lambda self: self._parse_analyze_histogram(), + "ESTIMATE": lambda self: self._parse_analyze_statistics(), + "LIST": lambda self: self._parse_analyze_list(), + "PREDICATE": lambda self: self._parse_analyze_columns(), + "UPDATE": lambda self: self._parse_analyze_histogram(), + "VALIDATE": lambda self: self._parse_analyze_validate(), + } + + PARTITION_KEYWORDS = {"PARTITION", "SUBPARTITION"} + + AMBIGUOUS_ALIAS_TOKENS = (TokenType.LIMIT, TokenType.OFFSET) + + OPERATION_MODIFIERS: t.Set[str] = set() + + RECURSIVE_CTE_SEARCH_KIND = {"BREADTH", "DEPTH", "CYCLE"} + + MODIFIABLES = (exp.Query, exp.Table, exp.TableFromRows, exp.Values) + + STRICT_CAST = True + + PREFIXED_PIVOT_COLUMNS = False + IDENTIFY_PIVOT_STRINGS = False + + LOG_DEFAULTS_TO_LN = False + + # Whether the table sample clause expects CSV syntax + TABLESAMPLE_CSV = False + + # The default method used for table sampling + DEFAULT_SAMPLING_METHOD: t.Optional[str] = None + + # Whether the SET command needs a delimiter (e.g. "=") for assignments + SET_REQUIRES_ASSIGNMENT_DELIMITER = True + + # Whether the TRIM function expects the characters to trim as its first argument + TRIM_PATTERN_FIRST = False + + # Whether string aliases are supported `SELECT COUNT(*) 'count'` + STRING_ALIASES = False + + # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) + MODIFIERS_ATTACHED_TO_SET_OP = True + SET_OP_MODIFIERS = {"order", "limit", "offset"} + + # Whether to parse IF statements that aren't followed by a left parenthesis as commands + NO_PAREN_IF_COMMANDS = True + + # Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres) + JSON_ARROWS_REQUIRE_JSON_TYPE = False + + # Whether the `:` operator is used to extract a value from a VARIANT column + COLON_IS_VARIANT_EXTRACT = False + + # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. + # If this is True and '(' is not found, the keyword will be treated as an identifier + VALUES_FOLLOWED_BY_PAREN = True + + # Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift) + SUPPORTS_IMPLICIT_UNNEST = False + + # Whether or not interval spans are supported, INTERVAL 1 YEAR TO MONTHS + INTERVAL_SPANS = True + + # Whether a PARTITION clause can follow a table reference + SUPPORTS_PARTITION_SELECTION = False + + # Whether the `name AS expr` schema/column constraint requires parentheses around `expr` + WRAPPED_TRANSFORM_COLUMN_CONSTRAINT = True + + # Whether the 'AS' keyword is optional in the CTE definition syntax + OPTIONAL_ALIAS_TOKEN_CTE = True + + # Whether renaming a column with an ALTER statement requires the presence of the COLUMN keyword + ALTER_RENAME_REQUIRES_COLUMN = True + + # Whether Alter statements are allowed to contain Partition specifications + ALTER_TABLE_PARTITIONS = False + + # Whether all join types have the same precedence, i.e., they "naturally" produce a left-deep tree. + # In standard SQL, joins that use the JOIN keyword take higher precedence than comma-joins. That is + # to say, JOIN operators happen before comma operators. This is not the case in some dialects, such + # as BigQuery, where all joins have the same precedence. 
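+    # e.g. `FROM a, b JOIN c ON ...` binds as `a, (b JOIN c)` under standard
+    # precedence, but as `(a, b) JOIN c` when this flag is True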
+ JOINS_HAVE_EQUAL_PRECEDENCE = False + + # Whether TIMESTAMP can produce a zone-aware timestamp + ZONE_AWARE_TIMESTAMP_CONSTRUCTOR = False + + # Whether map literals support arbitrary expressions as keys. + # When True, allows complex keys like arrays or literals: {[1, 2]: 3}, {1: 2} (e.g. DuckDB). + # When False, keys are typically restricted to identifiers. + MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS = False + + # Whether JSON_EXTRACT requires a JSON expression as the first argument, e.g. this + # is true for Snowflake but not for BigQuery, which can also process strings + JSON_EXTRACT_REQUIRES_JSON_EXPRESSION = False + + # Dialects like Databricks support JOINs without join criteria. + # Adding an ON TRUE makes transpilation semantically correct for other dialects + ADD_JOIN_ON_TRUE = False + + # Whether INTERVAL spans with literal format '\d+ hh:[mm:[ss[.ff]]]' + # can omit the span unit `DAY TO MINUTE` or `DAY TO SECOND` + SUPPORTS_OMITTED_INTERVAL_SPAN_UNIT = False + + __slots__ = ( + "error_level", + "error_message_context", + "max_errors", + "dialect", + "sql", + "errors", + "_tokens", + "_index", + "_curr", + "_next", + "_prev", + "_prev_comments", + "_pipe_cte_counter", + ) + + # Autofilled + SHOW_TRIE: t.Dict = {} + SET_TRIE: t.Dict = {} + + def __init__( + self, + error_level: t.Optional[ErrorLevel] = None, + error_message_context: int = 100, + max_errors: int = 3, + dialect: DialectType = None, + ): + from bigframes_vendored.sqlglot.dialects import Dialect + + self.error_level = error_level or ErrorLevel.IMMEDIATE + self.error_message_context = error_message_context + self.max_errors = max_errors + self.dialect = Dialect.get_or_raise(dialect) + self.reset() + + def reset(self): + self.sql = "" + self.errors = [] + self._tokens = [] + self._index = 0 + self._curr = None + self._next = None + self._prev = None + self._prev_comments = None + self._pipe_cte_counter = 0 + + def parse( + self, raw_tokens: t.List[Token], sql: t.Optional[str] = None + ) -> t.List[t.Optional[exp.Expression]]: + """ + Parses a list of tokens and returns a list of syntax trees, one tree + per parsed SQL statement. + + Args: + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. + + Returns: + The list of the produced syntax trees. + """ + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) + + def parse_into( + self, + expression_types: exp.IntoType, + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: + """ + Parses a list of tokens into a given Expression type. If a collection of Expression + types is given instead, this method will try to parse the token list into each one + of them, stopping at the first for which the parsing succeeds. + + Args: + expression_types: The expression type(s) to try and parse the token list into. + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. + + Returns: + The target Expression.
+ """ + errors = [] + for expression_type in ensure_list(expression_types): + parser = self.EXPRESSION_PARSERS.get(expression_type) + if not parser: + raise TypeError(f"No parser registered for {expression_type}") + + try: + return self._parse(parser, raw_tokens, sql) + except ParseError as e: + e.errors[0]["into_expression"] = expression_type + errors.append(e) + + raise ParseError( + f"Failed to parse '{sql or raw_tokens}' into {expression_types}", + errors=merge_errors(errors), + ) from errors[-1] + + def _parse( + self, + parse_method: t.Callable[[Parser], t.Optional[exp.Expression]], + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: + self.reset() + self.sql = sql or "" + + total = len(raw_tokens) + chunks: t.List[t.List[Token]] = [[]] + + for i, token in enumerate(raw_tokens): + if token.token_type == TokenType.SEMICOLON: + if token.comments: + chunks.append([token]) + + if i < total - 1: + chunks.append([]) + else: + chunks[-1].append(token) + + expressions = [] + + for tokens in chunks: + self._index = -1 + self._tokens = tokens + self._advance() + + expressions.append(parse_method(self)) + + if self._index < len(self._tokens): + self.raise_error("Invalid expression / Unexpected token") + + self.check_errors() + + return expressions + + def check_errors(self) -> None: + """Logs or raises any found errors, depending on the chosen error level setting.""" + if self.error_level == ErrorLevel.WARN: + for error in self.errors: + logger.error(str(error)) + elif self.error_level == ErrorLevel.RAISE and self.errors: + raise ParseError( + concat_messages(self.errors, self.max_errors), + errors=merge_errors(self.errors), + ) + + def raise_error(self, message: str, token: t.Optional[Token] = None) -> None: + """ + Appends an error in the list of recorded errors or raises it, depending on the chosen + error level setting. + """ + token = token or self._curr or self._prev or Token.string("") + formatted_sql, start_context, highlight, end_context = highlight_sql( + sql=self.sql, + positions=[(token.start, token.end)], + context_length=self.error_message_context, + ) + formatted_message = ( + f"{message}. Line {token.line}, Col: {token.col}.\n {formatted_sql}" + ) + + error = ParseError.new( + formatted_message, + description=message, + line=token.line, + col=token.col, + start_context=start_context, + highlight=highlight, + end_context=end_context, + ) + + if self.error_level == ErrorLevel.IMMEDIATE: + raise error + + self.errors.append(error) + + def expression( + self, + exp_class: t.Type[E], + token: t.Optional[Token] = None, + comments: t.Optional[t.List[str]] = None, + **kwargs, + ) -> E: + """ + Creates a new, validated Expression. + + Args: + exp_class: The expression class to instantiate. + comments: An optional list of comments to attach to the expression. + kwargs: The arguments to set for the expression along with their respective values. + + Returns: + The target expression. 
+ """ + if token: + instance = exp_class(this=token.text, **kwargs) + instance.update_positions(token) + else: + instance = exp_class(**kwargs) + instance.add_comments(comments) if comments else self._add_comments(instance) + return self.validate_expression(instance) + + def _add_comments(self, expression: t.Optional[exp.Expression]) -> None: + if expression and self._prev_comments: + expression.add_comments(self._prev_comments) + self._prev_comments = None + + def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E: + """ + Validates an Expression, making sure that all its mandatory arguments are set. + + Args: + expression: The expression to validate. + args: An optional list of items that was used to instantiate the expression, if it's a Func. + + Returns: + The validated expression. + """ + if self.error_level != ErrorLevel.IGNORE: + for error_message in expression.error_messages(args): + self.raise_error(error_message) + + return expression + + def _find_sql(self, start: Token, end: Token) -> str: + return self.sql[start.start : end.end + 1] + + def _is_connected(self) -> bool: + return self._prev and self._curr and self._prev.end + 1 == self._curr.start + + def _advance(self, times: int = 1) -> None: + self._index += times + self._curr = seq_get(self._tokens, self._index) + self._next = seq_get(self._tokens, self._index + 1) + + if self._index > 0: + self._prev = self._tokens[self._index - 1] + self._prev_comments = self._prev.comments + else: + self._prev = None + self._prev_comments = None + + def _retreat(self, index: int) -> None: + if index != self._index: + self._advance(index - self._index) + + def _warn_unsupported(self) -> None: + if len(self._tokens) <= 1: + return + + # We use _find_sql because self.sql may comprise multiple chunks, and we're only + # interested in emitting a warning for the one being currently processed. + sql = self._find_sql(self._tokens[0], self._tokens[-1])[ + : self.error_message_context + ] + + logger.warning( + f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'." + ) + + def _parse_command(self) -> exp.Command: + self._warn_unsupported() + return self.expression( + exp.Command, + comments=self._prev_comments, + this=self._prev.text.upper(), + expression=self._parse_string(), + ) + + def _try_parse( + self, parse_method: t.Callable[[], T], retreat: bool = False + ) -> t.Optional[T]: + """ + Attemps to backtrack if a parse function that contains a try/catch internally raises an error. 
+ This behavior can be different depending on the user-set ErrorLevel, so _try_parse aims to + solve this by setting & resetting the parser state accordingly + """ + index = self._index + error_level = self.error_level + + self.error_level = ErrorLevel.IMMEDIATE + try: + this = parse_method() + except ParseError: + this = None + finally: + if not this or retreat: + self._retreat(index) + self.error_level = error_level + + return this + + def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: + start = self._prev + exists = self._parse_exists() if allow_exists else None + + self._match(TokenType.ON) + + materialized = self._match_text_seq("MATERIALIZED") + kind = self._match_set(self.CREATABLES) and self._prev + if not kind: + return self._parse_as_command(start) + + if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): + this = self._parse_user_defined_function(kind=kind.token_type) + elif kind.token_type == TokenType.TABLE: + this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS) + elif kind.token_type == TokenType.COLUMN: + this = self._parse_column() + else: + this = self._parse_id_var() + + self._match(TokenType.IS) + + return self.expression( + exp.Comment, + this=this, + kind=kind.text, + expression=self._parse_string(), + exists=exists, + materialized=materialized, + ) + + def _parse_to_table( + self, + ) -> exp.ToTableProperty: + table = self._parse_table_parts(schema=True) + return self.expression(exp.ToTableProperty, this=table) + + # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl + def _parse_ttl(self) -> exp.Expression: + def _parse_ttl_action() -> t.Optional[exp.Expression]: + this = self._parse_bitwise() + + if self._match_text_seq("DELETE"): + return self.expression(exp.MergeTreeTTLAction, this=this, delete=True) + if self._match_text_seq("RECOMPRESS"): + return self.expression( + exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise() + ) + if self._match_text_seq("TO", "DISK"): + return self.expression( + exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string() + ) + if self._match_text_seq("TO", "VOLUME"): + return self.expression( + exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string() + ) + + return this + + expressions = self._parse_csv(_parse_ttl_action) + where = self._parse_where() + group = self._parse_group() + + aggregates = None + if group and self._match(TokenType.SET): + aggregates = self._parse_csv(self._parse_set_item) + + return self.expression( + exp.MergeTreeTTL, + expressions=expressions, + where=where, + group=group, + aggregates=aggregates, + ) + + def _parse_statement(self) -> t.Optional[exp.Expression]: + if self._curr is None: + return None + + if self._match_set(self.STATEMENT_PARSERS): + comments = self._prev_comments + stmt = self.STATEMENT_PARSERS[self._prev.token_type](self) + stmt.add_comments(comments, prepend=True) + return stmt + + if self._match_set(self.dialect.tokenizer_class.COMMANDS): + return self._parse_command() + + expression = self._parse_expression() + expression = ( + self._parse_set_operations(expression) + if expression + else self._parse_select() + ) + return self._parse_query_modifiers(expression) + + def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command: + start = self._prev + temporary = self._match(TokenType.TEMPORARY) + materialized = self._match_text_seq("MATERIALIZED") + + kind = self._match_set(self.CREATABLES) and self._prev.text.upper() + if not kind: + return
self._parse_as_command(start) + + concurrently = self._match_text_seq("CONCURRENTLY") + if_exists = exists or self._parse_exists() + + if kind == "COLUMN": + this = self._parse_column() + else: + this = self._parse_table_parts( + schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA + ) + + cluster = self._parse_on_property() if self._match(TokenType.ON) else None + + if self._match(TokenType.L_PAREN, advance=False): + expressions = self._parse_wrapped_csv(self._parse_types) + else: + expressions = None + + return self.expression( + exp.Drop, + exists=if_exists, + this=this, + expressions=expressions, + kind=self.dialect.CREATABLE_KIND_MAPPING.get(kind) or kind, + temporary=temporary, + materialized=materialized, + cascade=self._match_text_seq("CASCADE"), + constraints=self._match_text_seq("CONSTRAINTS"), + purge=self._match_text_seq("PURGE"), + cluster=cluster, + concurrently=concurrently, + ) + + def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: + return ( + self._match_text_seq("IF") + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) + + def _parse_create(self) -> exp.Create | exp.Command: + # Note: this can't be None because we've matched a statement parser + start = self._prev + + replace = ( + start.token_type == TokenType.REPLACE + or self._match_pair(TokenType.OR, TokenType.REPLACE) + or self._match_pair(TokenType.OR, TokenType.ALTER) + ) + refresh = self._match_pair(TokenType.OR, TokenType.REFRESH) + + unique = self._match(TokenType.UNIQUE) + + if self._match_text_seq("CLUSTERED", "COLUMNSTORE"): + clustered = True + elif self._match_text_seq( + "NONCLUSTERED", "COLUMNSTORE" + ) or self._match_text_seq("COLUMNSTORE"): + clustered = False + else: + clustered = None + + if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): + self._advance() + + properties = None + create_token = self._match_set(self.CREATABLES) and self._prev + + if not create_token: + # exp.Properties.Location.POST_CREATE + properties = self._parse_properties() + create_token = self._match_set(self.CREATABLES) and self._prev + + if not properties or not create_token: + return self._parse_as_command(start) + + concurrently = self._match_text_seq("CONCURRENTLY") + exists = self._parse_exists(not_=True) + this = None + expression: t.Optional[exp.Expression] = None + indexes = None + no_schema_binding = None + begin = None + end = None + clone = None + + def extend_props(temp_props: t.Optional[exp.Properties]) -> None: + nonlocal properties + if properties and temp_props: + properties.expressions.extend(temp_props.expressions) + elif temp_props: + properties = temp_props + + if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): + this = self._parse_user_defined_function(kind=create_token.token_type) + + # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) + extend_props(self._parse_properties()) + + expression = self._match(TokenType.ALIAS) and self._parse_heredoc() + extend_props(self._parse_properties()) + + if not expression: + if self._match(TokenType.COMMAND): + expression = self._parse_as_command(self._prev) + else: + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + + if self._match(TokenType.STRING, advance=False): + # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property + # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement + expression = 
self._parse_string() + extend_props(self._parse_properties()) + else: + expression = self._parse_user_defined_function_expression() + + end = self._match_text_seq("END") + + if return_: + expression = self.expression(exp.Return, this=expression) + elif create_token.token_type == TokenType.INDEX: + # Postgres allows anonymous indexes, e.g. CREATE INDEX IF NOT EXISTS ON t(c) + if not self._match(TokenType.ON): + index = self._parse_id_var() + anonymous = False + else: + index = None + anonymous = True + + this = self._parse_index(index=index, anonymous=anonymous) + elif create_token.token_type in self.DB_CREATABLES: + table_parts = self._parse_table_parts( + schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA + ) + + # exp.Properties.Location.POST_NAME + self._match(TokenType.COMMA) + extend_props(self._parse_properties(before=True)) + + this = self._parse_schema(this=table_parts) + + # exp.Properties.Location.POST_SCHEMA and POST_WITH + extend_props(self._parse_properties()) + + has_alias = self._match(TokenType.ALIAS) + if not self._match_set(self.DDL_SELECT_TOKENS, advance=False): + # exp.Properties.Location.POST_ALIAS + extend_props(self._parse_properties()) + + if create_token.token_type == TokenType.SEQUENCE: + expression = self._parse_types() + props = self._parse_properties() + if props: + sequence_props = exp.SequenceProperties() + options = [] + for prop in props: + if isinstance(prop, exp.SequenceProperties): + for arg, value in prop.args.items(): + if arg == "options": + options.extend(value) + else: + sequence_props.set(arg, value) + prop.pop() + + if options: + sequence_props.set("options", options) + + props.append("expressions", sequence_props) + extend_props(props) + else: + expression = self._parse_ddl_select() + + # Some dialects also support using a table as an alias instead of a SELECT. + # Here we fall back to this as an alternative.
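(The fallback described by the comment above is the `has_alias` check that follows.) For a concrete feel of what this CREATE branch produces, a hedged sketch that assumes the vendored copy keeps upstream sqlglot's `parse_one`/`exp` entry points:

```python
from bigframes_vendored.sqlglot import exp, parse_one

create = parse_one("CREATE OR REPLACE TABLE db.t AS SELECT 1 AS x")
assert isinstance(create, exp.Create)
print(create.args["kind"])      # 'TABLE'
print(create.args["replace"])   # True, from the OR REPLACE pair matched earlier
print(create.expression.sql())  # 'SELECT 1 AS x', parsed via _parse_ddl_select
```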
+ if not expression and has_alias: + expression = self._try_parse(self._parse_table_parts) + + if create_token.token_type == TokenType.TABLE: + # exp.Properties.Location.POST_EXPRESSION + extend_props(self._parse_properties()) + + indexes = [] + while True: + index = self._parse_index() + + # exp.Properties.Location.POST_INDEX + extend_props(self._parse_properties()) + if not index: + break + else: + self._match(TokenType.COMMA) + indexes.append(index) + elif create_token.token_type == TokenType.VIEW: + if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"): + no_schema_binding = True + elif create_token.token_type in (TokenType.SINK, TokenType.SOURCE): + extend_props(self._parse_properties()) + + shallow = self._match_text_seq("SHALLOW") + + if self._match_texts(self.CLONE_KEYWORDS): + copy = self._prev.text.lower() == "copy" + clone = self.expression( + exp.Clone, + this=self._parse_table(schema=True), + shallow=shallow, + copy=copy, + ) + + if self._curr and not self._match_set( + (TokenType.R_PAREN, TokenType.COMMA), advance=False + ): + return self._parse_as_command(start) + + create_kind_text = create_token.text.upper() + return self.expression( + exp.Create, + this=this, + kind=self.dialect.CREATABLE_KIND_MAPPING.get(create_kind_text) + or create_kind_text, + replace=replace, + refresh=refresh, + unique=unique, + expression=expression, + exists=exists, + properties=properties, + indexes=indexes, + no_schema_binding=no_schema_binding, + begin=begin, + end=end, + clone=clone, + concurrently=concurrently, + clustered=clustered, + ) + + def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]: + seq = exp.SequenceProperties() + + options = [] + index = self._index + + while self._curr: + self._match(TokenType.COMMA) + if self._match_text_seq("INCREMENT"): + self._match_text_seq("BY") + self._match_text_seq("=") + seq.set("increment", self._parse_term()) + elif self._match_text_seq("MINVALUE"): + seq.set("minvalue", self._parse_term()) + elif self._match_text_seq("MAXVALUE"): + seq.set("maxvalue", self._parse_term()) + elif self._match(TokenType.START_WITH) or self._match_text_seq("START"): + self._match_text_seq("=") + seq.set("start", self._parse_term()) + elif self._match_text_seq("CACHE"): + # T-SQL allows empty CACHE which is initialized dynamically + seq.set("cache", self._parse_number() or True) + elif self._match_text_seq("OWNED", "BY"): + # "OWNED BY NONE" is the default + seq.set( + "owned", + None if self._match_text_seq("NONE") else self._parse_column(), + ) + else: + opt = self._parse_var_from_options( + self.CREATE_SEQUENCE, raise_unmatched=False + ) + if opt: + options.append(opt) + else: + break + + seq.set("options", options if options else None) + return None if self._index == index else seq + + def _parse_property_before(self) -> t.Optional[exp.Expression]: + # only used for teradata currently + self._match(TokenType.COMMA) + + kwargs = { + "no": self._match_text_seq("NO"), + "dual": self._match_text_seq("DUAL"), + "before": self._match_text_seq("BEFORE"), + "default": self._match_text_seq("DEFAULT"), + "local": (self._match_text_seq("LOCAL") and "LOCAL") + or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"), + "after": self._match_text_seq("AFTER"), + "minimum": self._match_texts(("MIN", "MINIMUM")), + "maximum": self._match_texts(("MAX", "MAXIMUM")), + } + + if self._match_texts(self.PROPERTY_PARSERS): + parser = self.PROPERTY_PARSERS[self._prev.text.upper()] + try: + return parser(self, **{k: v for k, v in kwargs.items() if v}) + except 
TypeError: + self.raise_error(f"Cannot parse property '{self._prev.text}'") + + return None + + def _parse_wrapped_properties(self) -> t.List[exp.Expression]: + return self._parse_wrapped_csv(self._parse_property) + + def _parse_property(self) -> t.Optional[exp.Expression]: + if self._match_texts(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.text.upper()](self) + + if self._match(TokenType.DEFAULT) and self._match_texts(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.text.upper()](self, default=True) + + if self._match_text_seq("COMPOUND", "SORTKEY"): + return self._parse_sortkey(compound=True) + + if self._match_text_seq("SQL", "SECURITY"): + return self.expression( + exp.SqlSecurityProperty, + this=self._match_texts(("DEFINER", "INVOKER")) + and self._prev.text.upper(), + ) + + index = self._index + + seq_props = self._parse_sequence_properties() + if seq_props: + return seq_props + + self._retreat(index) + key = self._parse_column() + + if not self._match(TokenType.EQ): + self._retreat(index) + return None + + # Transform the key to exp.Dot if it's dotted identifiers wrapped in exp.Column or to exp.Var otherwise + if isinstance(key, exp.Column): + key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name) + + value = self._parse_bitwise() or self._parse_var(any_token=True) + + # Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier()) + if isinstance(value, exp.Column): + value = exp.var(value.name) + + return self.expression(exp.Property, this=key, value=value) + + def _parse_stored( + self, + ) -> t.Union[exp.FileFormatProperty, exp.StorageHandlerProperty]: + if self._match_text_seq("BY"): + return self.expression( + exp.StorageHandlerProperty, this=self._parse_var_or_string() + ) + + self._match(TokenType.ALIAS) + input_format = ( + self._parse_string() if self._match_text_seq("INPUTFORMAT") else None + ) + output_format = ( + self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None + ) + + return self.expression( + exp.FileFormatProperty, + this=( + self.expression( + exp.InputOutputFormat, + input_format=input_format, + output_format=output_format, + ) + if input_format or output_format + else self._parse_var_or_string() + or self._parse_number() + or self._parse_id_var() + ), + hive_format=True, + ) + + def _parse_unquoted_field(self) -> t.Optional[exp.Expression]: + field = self._parse_field() + if isinstance(field, exp.Identifier) and not field.quoted: + field = exp.var(field) + + return field + + def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E: + self._match(TokenType.EQ) + self._match(TokenType.ALIAS) + + return self.expression(exp_class, this=self._parse_unquoted_field(), **kwargs) + + def _parse_properties( + self, before: t.Optional[bool] = None + ) -> t.Optional[exp.Properties]: + properties = [] + while True: + if before: + prop = self._parse_property_before() + else: + prop = self._parse_property() + if not prop: + break + for p in ensure_list(prop): + properties.append(p) + + if properties: + return self.expression(exp.Properties, expressions=properties) + + return None + + def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty: + return self.expression( + exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") + ) + + def _parse_security(self) -> t.Optional[exp.SecurityProperty]: + if self._match_texts(("NONE", "DEFINER", "INVOKER")): + security_specifier = self._prev.text.upper() + return self.expression(exp.SecurityProperty, 
this=security_specifier) + return None + + def _parse_settings_property(self) -> exp.SettingsProperty: + return self.expression( + exp.SettingsProperty, expressions=self._parse_csv(self._parse_assignment) + ) + + def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty: + if self._index >= 2: + pre_volatile_token = self._tokens[self._index - 2] + else: + pre_volatile_token = None + + if ( + pre_volatile_token + and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS + ): + return exp.VolatileProperty() + + return self.expression( + exp.StabilityProperty, this=exp.Literal.string("VOLATILE") + ) + + def _parse_retention_period(self) -> exp.Var: + # Parse TSQL's HISTORY_RETENTION_PERIOD: {INFINITE | DAY | DAYS | MONTH ...} + number = self._parse_number() + number_str = f"{number} " if number else "" + unit = self._parse_var(any_token=True) + return exp.var(f"{number_str}{unit}") + + def _parse_system_versioning_property( + self, with_: bool = False + ) -> exp.WithSystemVersioningProperty: + self._match(TokenType.EQ) + prop = self.expression( + exp.WithSystemVersioningProperty, + on=True, + with_=with_, + ) + + if self._match_text_seq("OFF"): + prop.set("on", False) + return prop + + self._match(TokenType.ON) + if self._match(TokenType.L_PAREN): + while self._curr and not self._match(TokenType.R_PAREN): + if self._match_text_seq("HISTORY_TABLE", "="): + prop.set("this", self._parse_table_parts()) + elif self._match_text_seq("DATA_CONSISTENCY_CHECK", "="): + prop.set( + "data_consistency", + self._advance_any() and self._prev.text.upper(), + ) + elif self._match_text_seq("HISTORY_RETENTION_PERIOD", "="): + prop.set("retention_period", self._parse_retention_period()) + + self._match(TokenType.COMMA) + + return prop + + def _parse_data_deletion_property(self) -> exp.DataDeletionProperty: + self._match(TokenType.EQ) + on = self._match_text_seq("ON") or not self._match_text_seq("OFF") + prop = self.expression(exp.DataDeletionProperty, on=on) + + if self._match(TokenType.L_PAREN): + while self._curr and not self._match(TokenType.R_PAREN): + if self._match_text_seq("FILTER_COLUMN", "="): + prop.set("filter_column", self._parse_column()) + elif self._match_text_seq("RETENTION_PERIOD", "="): + prop.set("retention_period", self._parse_retention_period()) + + self._match(TokenType.COMMA) + + return prop + + def _parse_distributed_property(self) -> exp.DistributedByProperty: + kind = "HASH" + expressions: t.Optional[t.List[exp.Expression]] = None + if self._match_text_seq("BY", "HASH"): + expressions = self._parse_wrapped_csv(self._parse_id_var) + elif self._match_text_seq("BY", "RANDOM"): + kind = "RANDOM" + + # If the BUCKETS keyword is not present, the number of buckets is AUTO + buckets: t.Optional[exp.Expression] = None + if self._match_text_seq("BUCKETS") and not self._match_text_seq("AUTO"): + buckets = self._parse_number() + + return self.expression( + exp.DistributedByProperty, + expressions=expressions, + kind=kind, + buckets=buckets, + order=self._parse_order(), + ) + + def _parse_composite_key_property(self, expr_type: t.Type[E]) -> E: + self._match_text_seq("KEY") + expressions = self._parse_wrapped_id_vars() + return self.expression(expr_type, expressions=expressions) + + def _parse_with_property( + self, + ) -> t.Optional[exp.Expression] | t.List[exp.Expression]: + if self._match_text_seq("(", "SYSTEM_VERSIONING"): + prop = self._parse_system_versioning_property(with_=True) + self._match_r_paren() + return prop + + if self._match(TokenType.L_PAREN, 
advance=False): + return self._parse_wrapped_properties() + + if self._match_text_seq("JOURNAL"): + return self._parse_withjournaltable() + + if self._match_texts(self.VIEW_ATTRIBUTES): + return self.expression( + exp.ViewAttributeProperty, this=self._prev.text.upper() + ) + + if self._match_text_seq("DATA"): + return self._parse_withdata(no=False) + elif self._match_text_seq("NO", "DATA"): + return self._parse_withdata(no=True) + + if self._match(TokenType.SERDE_PROPERTIES, advance=False): + return self._parse_serde_properties(with_=True) + + if self._match(TokenType.SCHEMA): + return self.expression( + exp.WithSchemaBindingProperty, + this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS), + ) + + if self._match_texts(self.PROCEDURE_OPTIONS, advance=False): + return self.expression( + exp.WithProcedureOptions, + expressions=self._parse_csv(self._parse_procedure_option), + ) + + if not self._next: + return None + + return self._parse_withisolatedloading() + + def _parse_procedure_option(self) -> exp.Expression | None: + if self._match_text_seq("EXECUTE", "AS"): + return self.expression( + exp.ExecuteAsProperty, + this=self._parse_var_from_options( + self.EXECUTE_AS_OPTIONS, raise_unmatched=False + ) + or self._parse_string(), + ) + + return self._parse_var_from_options(self.PROCEDURE_OPTIONS) + + # https://dev.mysql.com/doc/refman/8.0/en/create-view.html + def _parse_definer(self) -> t.Optional[exp.DefinerProperty]: + self._match(TokenType.EQ) + + user = self._parse_id_var() + self._match(TokenType.PARAMETER) + host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text) + + if not user or not host: + return None + + return exp.DefinerProperty(this=f"{user}@{host}") + + def _parse_withjournaltable(self) -> exp.WithJournalTableProperty: + self._match(TokenType.TABLE) + self._match(TokenType.EQ) + return self.expression( + exp.WithJournalTableProperty, this=self._parse_table_parts() + ) + + def _parse_log(self, no: bool = False) -> exp.LogProperty: + return self.expression(exp.LogProperty, no=no) + + def _parse_journal(self, **kwargs) -> exp.JournalProperty: + return self.expression(exp.JournalProperty, **kwargs) + + def _parse_checksum(self) -> exp.ChecksumProperty: + self._match(TokenType.EQ) + + on = None + if self._match(TokenType.ON): + on = True + elif self._match_text_seq("OFF"): + on = False + + return self.expression( + exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT) + ) + + def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster: + return self.expression( + exp.Cluster, + expressions=( + self._parse_wrapped_csv(self._parse_ordered) + if wrapped + else self._parse_csv(self._parse_ordered) + ), + ) + + def _parse_clustered_by(self) -> exp.ClusteredByProperty: + self._match_text_seq("BY") + + self._match_l_paren() + expressions = self._parse_csv(self._parse_column) + self._match_r_paren() + + if self._match_text_seq("SORTED", "BY"): + self._match_l_paren() + sorted_by = self._parse_csv(self._parse_ordered) + self._match_r_paren() + else: + sorted_by = None + + self._match(TokenType.INTO) + buckets = self._parse_number() + self._match_text_seq("BUCKETS") + + return self.expression( + exp.ClusteredByProperty, + expressions=expressions, + sorted_by=sorted_by, + buckets=buckets, + ) + + def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]: + if not self._match_text_seq("GRANTS"): + self._retreat(self._index - 1) + return None + + return self.expression(exp.CopyGrantsProperty) + + def _parse_freespace(self) -> 
exp.FreespaceProperty: + self._match(TokenType.EQ) + return self.expression( + exp.FreespaceProperty, + this=self._parse_number(), + percent=self._match(TokenType.PERCENT), + ) + + def _parse_mergeblockratio( + self, no: bool = False, default: bool = False + ) -> exp.MergeBlockRatioProperty: + if self._match(TokenType.EQ): + return self.expression( + exp.MergeBlockRatioProperty, + this=self._parse_number(), + percent=self._match(TokenType.PERCENT), + ) + + return self.expression(exp.MergeBlockRatioProperty, no=no, default=default) + + def _parse_datablocksize( + self, + default: t.Optional[bool] = None, + minimum: t.Optional[bool] = None, + maximum: t.Optional[bool] = None, + ) -> exp.DataBlocksizeProperty: + self._match(TokenType.EQ) + size = self._parse_number() + + units = None + if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): + units = self._prev.text + + return self.expression( + exp.DataBlocksizeProperty, + size=size, + units=units, + default=default, + minimum=minimum, + maximum=maximum, + ) + + def _parse_blockcompression(self) -> exp.BlockCompressionProperty: + self._match(TokenType.EQ) + always = self._match_text_seq("ALWAYS") + manual = self._match_text_seq("MANUAL") + never = self._match_text_seq("NEVER") + default = self._match_text_seq("DEFAULT") + + autotemp = None + if self._match_text_seq("AUTOTEMP"): + autotemp = self._parse_schema() + + return self.expression( + exp.BlockCompressionProperty, + always=always, + manual=manual, + never=never, + default=default, + autotemp=autotemp, + ) + + def _parse_withisolatedloading(self) -> t.Optional[exp.IsolatedLoadingProperty]: + index = self._index + no = self._match_text_seq("NO") + concurrent = self._match_text_seq("CONCURRENT") + + if not self._match_text_seq("ISOLATED", "LOADING"): + self._retreat(index) + return None + + target = self._parse_var_from_options( + self.ISOLATED_LOADING_OPTIONS, raise_unmatched=False + ) + return self.expression( + exp.IsolatedLoadingProperty, no=no, concurrent=concurrent, target=target + ) + + def _parse_locking(self) -> exp.LockingProperty: + if self._match(TokenType.TABLE): + kind = "TABLE" + elif self._match(TokenType.VIEW): + kind = "VIEW" + elif self._match(TokenType.ROW): + kind = "ROW" + elif self._match_text_seq("DATABASE"): + kind = "DATABASE" + else: + kind = None + + if kind in ("DATABASE", "TABLE", "VIEW"): + this = self._parse_table_parts() + else: + this = None + + if self._match(TokenType.FOR): + for_or_in = "FOR" + elif self._match(TokenType.IN): + for_or_in = "IN" + else: + for_or_in = None + + if self._match_text_seq("ACCESS"): + lock_type = "ACCESS" + elif self._match_texts(("EXCL", "EXCLUSIVE")): + lock_type = "EXCLUSIVE" + elif self._match_text_seq("SHARE"): + lock_type = "SHARE" + elif self._match_text_seq("READ"): + lock_type = "READ" + elif self._match_text_seq("WRITE"): + lock_type = "WRITE" + elif self._match_text_seq("CHECKSUM"): + lock_type = "CHECKSUM" + else: + lock_type = None + + override = self._match_text_seq("OVERRIDE") + + return self.expression( + exp.LockingProperty, + this=this, + kind=kind, + for_or_in=for_or_in, + lock_type=lock_type, + override=override, + ) + + def _parse_partition_by(self) -> t.List[exp.Expression]: + if self._match(TokenType.PARTITION_BY): + return self._parse_csv(self._parse_disjunction) + return [] + + def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec: + def _parse_partition_bound_expr() -> t.Optional[exp.Expression]: + if self._match_text_seq("MINVALUE"): + return exp.var("MINVALUE") + if 
self._match_text_seq("MAXVALUE"): + return exp.var("MAXVALUE") + return self._parse_bitwise() + + this: t.Optional[exp.Expression | t.List[exp.Expression]] = None + expression = None + from_expressions = None + to_expressions = None + + if self._match(TokenType.IN): + this = self._parse_wrapped_csv(self._parse_bitwise) + elif self._match(TokenType.FROM): + from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) + self._match_text_seq("TO") + to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) + elif self._match_text_seq("WITH", "(", "MODULUS"): + this = self._parse_number() + self._match_text_seq(",", "REMAINDER") + expression = self._parse_number() + self._match_r_paren() + else: + self.raise_error("Failed to parse partition bound spec.") + + return self.expression( + exp.PartitionBoundSpec, + this=this, + expression=expression, + from_expressions=from_expressions, + to_expressions=to_expressions, + ) + + # https://www.postgresql.org/docs/current/sql-createtable.html + def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]: + if not self._match_text_seq("OF"): + self._retreat(self._index - 1) + return None + + this = self._parse_table(schema=True) + + if self._match(TokenType.DEFAULT): + expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT") + elif self._match_text_seq("FOR", "VALUES"): + expression = self._parse_partition_bound_spec() + else: + self.raise_error("Expecting either DEFAULT or FOR VALUES clause.") + + return self.expression( + exp.PartitionedOfProperty, this=this, expression=expression + ) + + def _parse_partitioned_by(self) -> exp.PartitionedByProperty: + self._match(TokenType.EQ) + return self.expression( + exp.PartitionedByProperty, + this=self._parse_schema() or self._parse_bracket(self._parse_field()), + ) + + def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty: + if self._match_text_seq("AND", "STATISTICS"): + statistics = True + elif self._match_text_seq("AND", "NO", "STATISTICS"): + statistics = False + else: + statistics = None + + return self.expression(exp.WithDataProperty, no=no, statistics=statistics) + + def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL") + return None + + def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA") + return None + + def _parse_no_property(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("PRIMARY", "INDEX"): + return exp.NoPrimaryIndexProperty() + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="NO SQL") + return None + + def _parse_on_property(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): + return exp.OnCommitProperty() + if self._match_text_seq("COMMIT", "DELETE", "ROWS"): + return exp.OnCommitProperty(delete=True) + return self.expression( + exp.OnProperty, this=self._parse_schema(self._parse_id_var()) + ) + + def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA") + return None + + def _parse_distkey(self) -> exp.DistKeyProperty: + return self.expression( + exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var) + ) + + def 
_parse_create_like(self) -> t.Optional[exp.LikeProperty]: + table = self._parse_table(schema=True) + + options = [] + while self._match_texts(("INCLUDING", "EXCLUDING")): + this = self._prev.text.upper() + + id_var = self._parse_id_var() + if not id_var: + return None + + options.append( + self.expression( + exp.Property, this=this, value=exp.var(id_var.this.upper()) + ) + ) + + return self.expression(exp.LikeProperty, this=table, expressions=options) + + def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty: + return self.expression( + exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound + ) + + def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty: + self._match(TokenType.EQ) + return self.expression( + exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default + ) + + def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty: + self._match_text_seq("WITH", "CONNECTION") + return self.expression( + exp.RemoteWithConnectionModelProperty, this=self._parse_table_parts() + ) + + def _parse_returns(self) -> exp.ReturnsProperty: + value: t.Optional[exp.Expression] + null = None + is_table = self._match(TokenType.TABLE) + + if is_table: + if self._match(TokenType.LT): + value = self.expression( + exp.Schema, + this="TABLE", + expressions=self._parse_csv(self._parse_struct_types), + ) + if not self._match(TokenType.GT): + self.raise_error("Expecting >") + else: + value = self._parse_schema(exp.var("TABLE")) + elif self._match_text_seq("NULL", "ON", "NULL", "INPUT"): + null = True + value = None + else: + value = self._parse_types() + + return self.expression( + exp.ReturnsProperty, this=value, is_table=is_table, null=null + ) + + def _parse_describe(self) -> exp.Describe: + kind = self._match_set(self.CREATABLES) and self._prev.text + style = self._match_texts(self.DESCRIBE_STYLES) and self._prev.text.upper() + if self._match(TokenType.DOT): + style = None + self._retreat(self._index - 2) + + format = ( + self._parse_property() + if self._match(TokenType.FORMAT, advance=False) + else None + ) + + if self._match_set(self.STATEMENT_PARSERS, advance=False): + this = self._parse_statement() + else: + this = self._parse_table(schema=True) + + properties = self._parse_properties() + expressions = properties.expressions if properties else None + partition = self._parse_partition() + return self.expression( + exp.Describe, + this=this, + style=style, + kind=kind, + expressions=expressions, + partition=partition, + format=format, + ) + + def _parse_multitable_inserts( + self, comments: t.Optional[t.List[str]] + ) -> exp.MultitableInserts: + kind = self._prev.text.upper() + expressions = [] + + def parse_conditional_insert() -> t.Optional[exp.ConditionalInsert]: + if self._match(TokenType.WHEN): + expression = self._parse_disjunction() + self._match(TokenType.THEN) + else: + expression = None + + else_ = self._match(TokenType.ELSE) + + if not self._match(TokenType.INTO): + return None + + return self.expression( + exp.ConditionalInsert, + this=self.expression( + exp.Insert, + this=self._parse_table(schema=True), + expression=self._parse_derived_table_values(), + ), + expression=expression, + else_=else_, + ) + + expression = parse_conditional_insert() + while expression is not None: + expressions.append(expression) + expression = parse_conditional_insert() + + return self.expression( + exp.MultitableInserts, + kind=kind, + comments=comments, + expressions=expressions, + source=self._parse_table(), + 
) + + def _parse_insert(self) -> t.Union[exp.Insert, exp.MultitableInserts]: + comments = [] + hint = self._parse_hint() + overwrite = self._match(TokenType.OVERWRITE) + ignore = self._match(TokenType.IGNORE) + local = self._match_text_seq("LOCAL") + alternative = None + is_function = None + + if self._match_text_seq("DIRECTORY"): + this: t.Optional[exp.Expression] = self.expression( + exp.Directory, + this=self._parse_var_or_string(), + local=local, + row_format=self._parse_row_format(match_row=True), + ) + else: + if self._match_set((TokenType.FIRST, TokenType.ALL)): + comments += ensure_list(self._prev_comments) + return self._parse_multitable_inserts(comments) + + if self._match(TokenType.OR): + alternative = ( + self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text + ) + + self._match(TokenType.INTO) + comments += ensure_list(self._prev_comments) + self._match(TokenType.TABLE) + is_function = self._match(TokenType.FUNCTION) + + this = self._parse_function() if is_function else self._parse_insert_table() + + returning = self._parse_returning() # TSQL allows RETURNING before source + + return self.expression( + exp.Insert, + comments=comments, + hint=hint, + is_function=is_function, + this=this, + stored=self._match_text_seq("STORED") and self._parse_stored(), + by_name=self._match_text_seq("BY", "NAME"), + exists=self._parse_exists(), + where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) + and self._parse_disjunction(), + partition=self._match(TokenType.PARTITION_BY) + and self._parse_partitioned_by(), + settings=self._match_text_seq("SETTINGS") + and self._parse_settings_property(), + default=self._match_text_seq("DEFAULT", "VALUES"), + expression=self._parse_derived_table_values() or self._parse_ddl_select(), + conflict=self._parse_on_conflict(), + returning=returning or self._parse_returning(), + overwrite=overwrite, + alternative=alternative, + ignore=ignore, + source=self._match(TokenType.TABLE) and self._parse_table(), + ) + + def _parse_insert_table(self) -> t.Optional[exp.Expression]: + this = self._parse_table(schema=True, parse_partition=True) + if isinstance(this, exp.Table) and self._match(TokenType.ALIAS, advance=False): + this.set("alias", self._parse_table_alias()) + return this + + def _parse_kill(self) -> exp.Kill: + kind = ( + exp.var(self._prev.text) + if self._match_texts(("CONNECTION", "QUERY")) + else None + ) + + return self.expression( + exp.Kill, + this=self._parse_primary(), + kind=kind, + ) + + def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: + conflict = self._match_text_seq("ON", "CONFLICT") + duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY") + + if not conflict and not duplicate: + return None + + conflict_keys = None + constraint = None + + if conflict: + if self._match_text_seq("ON", "CONSTRAINT"): + constraint = self._parse_id_var() + elif self._match(TokenType.L_PAREN): + conflict_keys = self._parse_csv(self._parse_id_var) + self._match_r_paren() + + action = self._parse_var_from_options(self.CONFLICT_ACTIONS) + if self._prev.token_type == TokenType.UPDATE: + self._match(TokenType.SET) + expressions = self._parse_csv(self._parse_equality) + else: + expressions = None + + return self.expression( + exp.OnConflict, + duplicate=duplicate, + expressions=expressions, + action=action, + conflict_keys=conflict_keys, + constraint=constraint, + where=self._parse_where(), + ) + + def _parse_returning(self) -> t.Optional[exp.Returning]: + if not self._match(TokenType.RETURNING): + return None + return self.expression( + 
exp.Returning, + expressions=self._parse_csv(self._parse_expression), + into=self._match(TokenType.INTO) and self._parse_table_part(), + ) + + def _parse_row( + self, + ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: + if not self._match(TokenType.FORMAT): + return None + return self._parse_row_format() + + def _parse_serde_properties( + self, with_: bool = False + ) -> t.Optional[exp.SerdeProperties]: + index = self._index + with_ = with_ or self._match_text_seq("WITH") + + if not self._match(TokenType.SERDE_PROPERTIES): + self._retreat(index) + return None + return self.expression( + exp.SerdeProperties, + expressions=self._parse_wrapped_properties(), + with_=with_, + ) + + def _parse_row_format( + self, match_row: bool = False + ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: + if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): + return None + + if self._match_text_seq("SERDE"): + this = self._parse_string() + + serde_properties = self._parse_serde_properties() + + return self.expression( + exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties + ) + + self._match_text_seq("DELIMITED") + + kwargs = {} + + if self._match_text_seq("FIELDS", "TERMINATED", "BY"): + kwargs["fields"] = self._parse_string() + if self._match_text_seq("ESCAPED", "BY"): + kwargs["escaped"] = self._parse_string() + if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"): + kwargs["collection_items"] = self._parse_string() + if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"): + kwargs["map_keys"] = self._parse_string() + if self._match_text_seq("LINES", "TERMINATED", "BY"): + kwargs["lines"] = self._parse_string() + if self._match_text_seq("NULL", "DEFINED", "AS"): + kwargs["null"] = self._parse_string() + + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore + + def _parse_load(self) -> exp.LoadData | exp.Command: + if self._match_text_seq("DATA"): + local = self._match_text_seq("LOCAL") + self._match_text_seq("INPATH") + inpath = self._parse_string() + overwrite = self._match(TokenType.OVERWRITE) + self._match_pair(TokenType.INTO, TokenType.TABLE) + + return self.expression( + exp.LoadData, + this=self._parse_table(schema=True), + local=local, + overwrite=overwrite, + inpath=inpath, + partition=self._parse_partition(), + input_format=self._match_text_seq("INPUTFORMAT") + and self._parse_string(), + serde=self._match_text_seq("SERDE") and self._parse_string(), + ) + return self._parse_as_command(self._prev) + + def _parse_delete(self) -> exp.Delete: + # This handles MySQL's "Multiple-Table Syntax" + # https://dev.mysql.com/doc/refman/8.0/en/delete.html + tables = None + if not self._match(TokenType.FROM, advance=False): + tables = self._parse_csv(self._parse_table) or None + + returning = self._parse_returning() + + return self.expression( + exp.Delete, + tables=tables, + this=self._match(TokenType.FROM) and self._parse_table(joins=True), + using=self._match(TokenType.USING) + and self._parse_csv(lambda: self._parse_table(joins=True)), + cluster=self._match(TokenType.ON) and self._parse_on_property(), + where=self._parse_where(), + returning=returning or self._parse_returning(), + order=self._parse_order(), + limit=self._parse_limit(), + ) + + def _parse_update(self) -> exp.Update: + kwargs: t.Dict[str, t.Any] = { + "this": self._parse_table( + joins=True, alias_tokens=self.UPDATE_ALIAS_TOKENS + ), + } + while self._curr: + if self._match(TokenType.SET): + 
kwargs["expressions"] = self._parse_csv(self._parse_equality) + elif self._match(TokenType.RETURNING, advance=False): + kwargs["returning"] = self._parse_returning() + elif self._match(TokenType.FROM, advance=False): + kwargs["from_"] = self._parse_from(joins=True) + elif self._match(TokenType.WHERE, advance=False): + kwargs["where"] = self._parse_where() + elif self._match(TokenType.ORDER_BY, advance=False): + kwargs["order"] = self._parse_order() + elif self._match(TokenType.LIMIT, advance=False): + kwargs["limit"] = self._parse_limit() + else: + break + + return self.expression(exp.Update, **kwargs) + + def _parse_use(self) -> exp.Use: + return self.expression( + exp.Use, + kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False), + this=self._parse_table(schema=False), + ) + + def _parse_uncache(self) -> exp.Uncache: + if not self._match(TokenType.TABLE): + self.raise_error("Expecting TABLE after UNCACHE") + + return self.expression( + exp.Uncache, + exists=self._parse_exists(), + this=self._parse_table(schema=True), + ) + + def _parse_cache(self) -> exp.Cache: + lazy = self._match_text_seq("LAZY") + self._match(TokenType.TABLE) + table = self._parse_table(schema=True) + + options = [] + if self._match_text_seq("OPTIONS"): + self._match_l_paren() + k = self._parse_string() + self._match(TokenType.EQ) + v = self._parse_string() + options = [k, v] + self._match_r_paren() + + self._match(TokenType.ALIAS) + return self.expression( + exp.Cache, + this=table, + lazy=lazy, + options=options, + expression=self._parse_select(nested=True), + ) + + def _parse_partition(self) -> t.Optional[exp.Partition]: + if not self._match_texts(self.PARTITION_KEYWORDS): + return None + + return self.expression( + exp.Partition, + subpartition=self._prev.text.upper() == "SUBPARTITION", + expressions=self._parse_wrapped_csv(self._parse_disjunction), + ) + + def _parse_value(self, values: bool = True) -> t.Optional[exp.Tuple]: + def _parse_value_expression() -> t.Optional[exp.Expression]: + if self.dialect.SUPPORTS_VALUES_DEFAULT and self._match(TokenType.DEFAULT): + return exp.var(self._prev.text.upper()) + return self._parse_expression() + + if self._match(TokenType.L_PAREN): + expressions = self._parse_csv(_parse_value_expression) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=expressions) + + # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows. 
+ expression = self._parse_expression() + if expression: + return self.expression(exp.Tuple, expressions=[expression]) + return None + + def _parse_projections(self) -> t.List[exp.Expression]: + return self._parse_expressions() + + def _parse_wrapped_select(self, table: bool = False) -> t.Optional[exp.Expression]: + if self._match_set((TokenType.PIVOT, TokenType.UNPIVOT)): + this: t.Optional[exp.Expression] = self._parse_simplified_pivot( + is_unpivot=self._prev.token_type == TokenType.UNPIVOT + ) + elif self._match(TokenType.FROM): + from_ = self._parse_from(skip_from_token=True, consume_pipe=True) + # Support parentheses for duckdb FROM-first syntax + select = self._parse_select(from_=from_) + if select: + if not select.args.get("from_"): + select.set("from_", from_) + this = select + else: + this = exp.select("*").from_(t.cast(exp.From, from_)) + this = self._parse_query_modifiers(self._parse_set_operations(this)) + else: + this = ( + self._parse_table(consume_pipe=True) + if table + else self._parse_select(nested=True, parse_set_operation=False) + ) + + # Transform exp.Values into a exp.Table to pass through parse_query_modifiers + # in case a modifier (e.g. join) is following + if table and isinstance(this, exp.Values) and this.alias: + alias = this.args["alias"].pop() + this = exp.Table(this=this, alias=alias) + + this = self._parse_query_modifiers(self._parse_set_operations(this)) + + return this + + def _parse_select( + self, + nested: bool = False, + table: bool = False, + parse_subquery_alias: bool = True, + parse_set_operation: bool = True, + consume_pipe: bool = True, + from_: t.Optional[exp.From] = None, + ) -> t.Optional[exp.Expression]: + query = self._parse_select_query( + nested=nested, + table=table, + parse_subquery_alias=parse_subquery_alias, + parse_set_operation=parse_set_operation, + ) + + if consume_pipe and self._match(TokenType.PIPE_GT, advance=False): + if not query and from_: + query = exp.select("*").from_(from_) + if isinstance(query, exp.Query): + query = self._parse_pipe_syntax_query(query) + query = query.subquery(copy=False) if query and table else query + + return query + + def _parse_select_query( + self, + nested: bool = False, + table: bool = False, + parse_subquery_alias: bool = True, + parse_set_operation: bool = True, + ) -> t.Optional[exp.Expression]: + cte = self._parse_with() + + if cte: + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + while isinstance(this, exp.Subquery) and this.is_wrapper: + this = this.this + + if "with_" in this.arg_types: + this.set("with_", cte) + else: + self.raise_error(f"{this.key} does not support CTE") + this = cte + + return this + + # duckdb supports leading with FROM x + from_ = ( + self._parse_from(joins=True, consume_pipe=True) + if self._match(TokenType.FROM, advance=False) + else None + ) + + if self._match(TokenType.SELECT): + comments = self._prev_comments + + hint = self._parse_hint() + + if self._next and not self._next.token_type == TokenType.DOT: + all_ = self._match(TokenType.ALL) + distinct = self._match_set(self.DISTINCT_TOKENS) + else: + all_, distinct = None, None + + kind = ( + self._match(TokenType.ALIAS) + and self._match_texts(("STRUCT", "VALUE")) + and self._prev.text.upper() + ) + + if distinct: + distinct = self.expression( + exp.Distinct, + on=self._parse_value(values=False) + if self._match(TokenType.ON) + else None, + ) + + if all_ and distinct: + self.raise_error("Cannot specify both ALL and DISTINCT after 
SELECT") + + operation_modifiers = [] + while self._curr and self._match_texts(self.OPERATION_MODIFIERS): + operation_modifiers.append(exp.var(self._prev.text.upper())) + + limit = self._parse_limit(top=True) + projections = self._parse_projections() + + this = self.expression( + exp.Select, + kind=kind, + hint=hint, + distinct=distinct, + expressions=projections, + limit=limit, + operation_modifiers=operation_modifiers or None, + ) + this.comments = comments + + into = self._parse_into() + if into: + this.set("into", into) + + if not from_: + from_ = self._parse_from() + + if from_: + this.set("from_", from_) + + this = self._parse_query_modifiers(this) + elif (table or nested) and self._match(TokenType.L_PAREN): + this = self._parse_wrapped_select(table=table) + + # We return early here so that the UNION isn't attached to the subquery by the + # following call to _parse_set_operations, but instead becomes the parent node + self._match_r_paren() + return self._parse_subquery(this, parse_alias=parse_subquery_alias) + elif self._match(TokenType.VALUES, advance=False): + this = self._parse_derived_table_values() + elif from_: + this = exp.select("*").from_(from_.this, copy=False) + elif self._match(TokenType.SUMMARIZE): + table = self._match(TokenType.TABLE) + this = self._parse_select() or self._parse_string() or self._parse_table() + return self.expression(exp.Summarize, this=this, table=table) + elif self._match(TokenType.DESCRIBE): + this = self._parse_describe() + else: + this = None + + return self._parse_set_operations(this) if parse_set_operation else this + + def _parse_recursive_with_search(self) -> t.Optional[exp.RecursiveWithSearch]: + self._match_text_seq("SEARCH") + + kind = ( + self._match_texts(self.RECURSIVE_CTE_SEARCH_KIND) + and self._prev.text.upper() + ) + + if not kind: + return None + + self._match_text_seq("FIRST", "BY") + + return self.expression( + exp.RecursiveWithSearch, + kind=kind, + this=self._parse_id_var(), + expression=self._match_text_seq("SET") and self._parse_id_var(), + using=self._match_text_seq("USING") and self._parse_id_var(), + ) + + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: + if not skip_with_token and not self._match(TokenType.WITH): + return None + + comments = self._prev_comments + recursive = self._match(TokenType.RECURSIVE) + + last_comments = None + expressions = [] + while True: + cte = self._parse_cte() + if isinstance(cte, exp.CTE): + expressions.append(cte) + if last_comments: + cte.add_comments(last_comments) + + if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH): + break + else: + self._match(TokenType.WITH) + + last_comments = self._prev_comments + + return self.expression( + exp.With, + comments=comments, + expressions=expressions, + recursive=recursive, + search=self._parse_recursive_with_search(), + ) + + def _parse_cte(self) -> t.Optional[exp.CTE]: + index = self._index + + alias = self._parse_table_alias(self.ID_VAR_TOKENS) + if not alias or not alias.this: + self.raise_error("Expected CTE to have alias") + + key_expressions = ( + self._parse_wrapped_id_vars() + if self._match_text_seq("USING", "KEY") + else None + ) + + if not self._match(TokenType.ALIAS) and not self.OPTIONAL_ALIAS_TOKEN_CTE: + self._retreat(index) + return None + + comments = self._prev_comments + + if self._match_text_seq("NOT", "MATERIALIZED"): + materialized = False + elif self._match_text_seq("MATERIALIZED"): + materialized = True + else: + materialized = None + + cte = self.expression( + exp.CTE, + 
this=self._parse_wrapped(self._parse_statement), + alias=alias, + materialized=materialized, + key_expressions=key_expressions, + comments=comments, + ) + + values = cte.this + if isinstance(values, exp.Values): + if values.alias: + cte.set("this", exp.select("*").from_(values)) + else: + cte.set( + "this", + exp.select("*").from_(exp.alias_(values, "_values", table=True)), + ) + + return cte + + def _parse_table_alias( + self, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.TableAlias]: + # In some dialects, LIMIT and OFFSET can act as both identifiers and keywords (clauses) + # so this section tries to parse the clause version and if it fails, it treats the token + # as an identifier (alias) + if self._can_parse_limit_or_offset(): + return None + + any_token = self._match(TokenType.ALIAS) + alias = ( + self._parse_id_var( + any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS + ) + or self._parse_string_as_identifier() + ) + + index = self._index + if self._match(TokenType.L_PAREN): + columns = self._parse_csv(self._parse_function_parameter) + self._match_r_paren() if columns else self._retreat(index) + else: + columns = None + + if not alias and not columns: + return None + + table_alias = self.expression(exp.TableAlias, this=alias, columns=columns) + + # We bubble up comments from the Identifier to the TableAlias + if isinstance(alias, exp.Identifier): + table_alias.add_comments(alias.pop_comments()) + + return table_alias + + def _parse_subquery( + self, this: t.Optional[exp.Expression], parse_alias: bool = True + ) -> t.Optional[exp.Subquery]: + if not this: + return None + + return self.expression( + exp.Subquery, + this=this, + pivots=self._parse_pivots(), + alias=self._parse_table_alias() if parse_alias else None, + sample=self._parse_table_sample(), + ) + + def _implicit_unnests_to_explicit(self, this: E) -> E: + from bigframes_vendored.sqlglot.optimizer.normalize_identifiers import ( + normalize_identifiers as _norm, + ) + + refs = { + _norm(this.args["from_"].this.copy(), dialect=self.dialect).alias_or_name + } + for i, join in enumerate(this.args.get("joins") or []): + table = join.this + normalized_table = table.copy() + normalized_table.meta["maybe_column"] = True + normalized_table = _norm(normalized_table, dialect=self.dialect) + + if isinstance(table, exp.Table) and not join.args.get("on"): + if normalized_table.parts[0].name in refs: + table_as_column = table.to_column() + unnest = exp.Unnest(expressions=[table_as_column]) + + # Table.to_column creates a parent Alias node that we want to convert to + # a TableAlias and attach to the Unnest, so it matches the parser's output + if isinstance(table.args.get("alias"), exp.TableAlias): + table_as_column.replace(table_as_column.this) + exp.alias_( + unnest, None, table=[table.args["alias"].this], copy=False + ) + + table.replace(unnest) + + refs.add(normalized_table.alias_or_name) + + return this + + @t.overload + def _parse_query_modifiers(self, this: E) -> E: + ... + + @t.overload + def _parse_query_modifiers(self, this: None) -> None: + ... 
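+    # A sketch of the flow below (illustrative, not exhaustive): for
+    # "SELECT * FROM t WHERE x > 0 GROUP BY y LIMIT 10", each clause keyword is
+    # looked up in QUERY_MODIFIER_PARSERS and the resulting Where/Group/Limit
+    # node is attached to the query via this.set(key, expression); a repeated
+    # clause, e.g. two WHEREs, raises "Found multiple 'WHERE' clauses".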
+ + def _parse_query_modifiers(self, this): + if isinstance(this, self.MODIFIABLES): + for join in self._parse_joins(): + this.append("joins", join) + for lateral in iter(self._parse_lateral, None): + this.append("laterals", lateral) + + while True: + if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False): + modifier_token = self._curr + parser = self.QUERY_MODIFIER_PARSERS[modifier_token.token_type] + key, expression = parser(self) + + if expression: + if this.args.get(key): + self.raise_error( + f"Found multiple '{modifier_token.text.upper()}' clauses", + token=modifier_token, + ) + + this.set(key, expression) + if key == "limit": + offset = expression.args.get("offset") + expression.set("offset", None) + + if offset: + offset = exp.Offset(expression=offset) + this.set("offset", offset) + + limit_by_expressions = expression.expressions + expression.set("expressions", None) + offset.set("expressions", limit_by_expressions) + continue + break + + if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from_"): + this = self._implicit_unnests_to_explicit(this) + + return this + + def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]: + start = self._curr + while self._curr: + self._advance() + + end = self._tokens[self._index - 1] + return exp.Hint(expressions=[self._find_sql(start, end)]) + + def _parse_hint_function_call(self) -> t.Optional[exp.Expression]: + return self._parse_function_call() + + def _parse_hint_body(self) -> t.Optional[exp.Hint]: + start_index = self._index + should_fallback_to_string = False + + hints = [] + try: + for hint in iter( + lambda: self._parse_csv( + lambda: self._parse_hint_function_call() + or self._parse_var(upper=True), + ), + [], + ): + hints.extend(hint) + except ParseError: + should_fallback_to_string = True + + if should_fallback_to_string or self._curr: + self._retreat(start_index) + return self._parse_hint_fallback_to_string() + + return self.expression(exp.Hint, expressions=hints) + + def _parse_hint(self) -> t.Optional[exp.Hint]: + if self._match(TokenType.HINT) and self._prev_comments: + return exp.maybe_parse( + self._prev_comments[0], into=exp.Hint, dialect=self.dialect + ) + + return None + + def _parse_into(self) -> t.Optional[exp.Into]: + if not self._match(TokenType.INTO): + return None + + temp = self._match(TokenType.TEMPORARY) + unlogged = self._match_text_seq("UNLOGGED") + self._match(TokenType.TABLE) + + return self.expression( + exp.Into, + this=self._parse_table(schema=True), + temporary=temp, + unlogged=unlogged, + ) + + def _parse_from( + self, + joins: bool = False, + skip_from_token: bool = False, + consume_pipe: bool = False, + ) -> t.Optional[exp.From]: + if not skip_from_token and not self._match(TokenType.FROM): + return None + + return self.expression( + exp.From, + comments=self._prev_comments, + this=self._parse_table(joins=joins, consume_pipe=consume_pipe), + ) + + def _parse_match_recognize_measure(self) -> exp.MatchRecognizeMeasure: + return self.expression( + exp.MatchRecognizeMeasure, + window_frame=self._match_texts(("FINAL", "RUNNING")) + and self._prev.text.upper(), + this=self._parse_expression(), + ) + + def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: + if not self._match(TokenType.MATCH_RECOGNIZE): + return None + + self._match_l_paren() + + partition = self._parse_partition_by() + order = self._parse_order() + + measures = ( + self._parse_csv(self._parse_match_recognize_measure) + if self._match_text_seq("MEASURES") + else None + ) + + if self._match_text_seq("ONE", 
"ROW", "PER", "MATCH"): + rows = exp.var("ONE ROW PER MATCH") + elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): + text = "ALL ROWS PER MATCH" + if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): + text += " SHOW EMPTY MATCHES" + elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"): + text += " OMIT EMPTY MATCHES" + elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): + text += " WITH UNMATCHED ROWS" + rows = exp.var(text) + else: + rows = None + + if self._match_text_seq("AFTER", "MATCH", "SKIP"): + text = "AFTER MATCH SKIP" + if self._match_text_seq("PAST", "LAST", "ROW"): + text += " PAST LAST ROW" + elif self._match_text_seq("TO", "NEXT", "ROW"): + text += " TO NEXT ROW" + elif self._match_text_seq("TO", "FIRST"): + text += f" TO FIRST {self._advance_any().text}" # type: ignore + elif self._match_text_seq("TO", "LAST"): + text += f" TO LAST {self._advance_any().text}" # type: ignore + after = exp.var(text) + else: + after = None + + if self._match_text_seq("PATTERN"): + self._match_l_paren() + + if not self._curr: + self.raise_error("Expecting )", self._curr) + + paren = 1 + start = self._curr + + while self._curr and paren > 0: + if self._curr.token_type == TokenType.L_PAREN: + paren += 1 + if self._curr.token_type == TokenType.R_PAREN: + paren -= 1 + + end = self._prev + self._advance() + + if paren > 0: + self.raise_error("Expecting )", self._curr) + + pattern = exp.var(self._find_sql(start, end)) + else: + pattern = None + + define = ( + self._parse_csv(self._parse_name_as_expression) + if self._match_text_seq("DEFINE") + else None + ) + + self._match_r_paren() + + return self.expression( + exp.MatchRecognize, + partition_by=partition, + order=order, + measures=measures, + rows=rows, + after=after, + pattern=pattern, + define=define, + alias=self._parse_table_alias(), + ) + + def _parse_lateral(self) -> t.Optional[exp.Lateral]: + cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) + if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY): + cross_apply = False + + if cross_apply is not None: + this = self._parse_select(table=True) + view = None + outer = None + elif self._match(TokenType.LATERAL): + this = self._parse_select(table=True) + view = self._match(TokenType.VIEW) + outer = self._match(TokenType.OUTER) + else: + return None + + if not this: + this = ( + self._parse_unnest() + or self._parse_function() + or self._parse_id_var(any_token=False) + ) + + while self._match(TokenType.DOT): + this = exp.Dot( + this=this, + expression=self._parse_function() + or self._parse_id_var(any_token=False), + ) + + ordinality: t.Optional[bool] = None + + if view: + table = self._parse_id_var(any_token=False) + columns = ( + self._parse_csv(self._parse_id_var) + if self._match(TokenType.ALIAS) + else [] + ) + table_alias: t.Optional[exp.TableAlias] = self.expression( + exp.TableAlias, this=table, columns=columns + ) + elif isinstance(this, (exp.Subquery, exp.Unnest)) and this.alias: + # We move the alias from the lateral's child node to the lateral itself + table_alias = this.args["alias"].pop() + else: + ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) + table_alias = self._parse_table_alias() + + return self.expression( + exp.Lateral, + this=this, + view=view, + outer=outer, + alias=table_alias, + cross_apply=cross_apply, + ordinality=ordinality, + ) + + def _parse_stream(self) -> t.Optional[exp.Stream]: + index = self._index + if self._match_text_seq("STREAM"): + this = self._try_parse(self._parse_table) + if this: + return 
self.expression(exp.Stream, this=this) + + self._retreat(index) + return None + + def _parse_join_parts( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: + return ( + self._match_set(self.JOIN_METHODS) and self._prev, + self._match_set(self.JOIN_SIDES) and self._prev, + self._match_set(self.JOIN_KINDS) and self._prev, + ) + + def _parse_using_identifiers(self) -> t.List[exp.Expression]: + def _parse_column_as_identifier() -> t.Optional[exp.Expression]: + this = self._parse_column() + if isinstance(this, exp.Column): + return this.this + return this + + return self._parse_wrapped_csv(_parse_column_as_identifier, optional=True) + + def _parse_join( + self, skip_join_token: bool = False, parse_bracket: bool = False + ) -> t.Optional[exp.Join]: + if self._match(TokenType.COMMA): + table = self._try_parse(self._parse_table) + cross_join = self.expression(exp.Join, this=table) if table else None + + if cross_join and self.JOINS_HAVE_EQUAL_PRECEDENCE: + cross_join.set("kind", "CROSS") + + return cross_join + + index = self._index + method, side, kind = self._parse_join_parts() + hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None + join = self._match(TokenType.JOIN) or ( + kind and kind.token_type == TokenType.STRAIGHT_JOIN + ) + join_comments = self._prev_comments + + if not skip_join_token and not join: + self._retreat(index) + kind = None + method = None + side = None + + outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False) + cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False) + + if not skip_join_token and not join and not outer_apply and not cross_apply: + return None + + kwargs: t.Dict[str, t.Any] = { + "this": self._parse_table(parse_bracket=parse_bracket) + } + if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA): + kwargs["expressions"] = self._parse_csv( + lambda: self._parse_table(parse_bracket=parse_bracket) + ) + + if method: + kwargs["method"] = method.text.upper() + if side: + kwargs["side"] = side.text.upper() + if kind: + kwargs["kind"] = kind.text.upper() + if hint: + kwargs["hint"] = hint + + if self._match(TokenType.MATCH_CONDITION): + kwargs["match_condition"] = self._parse_wrapped(self._parse_comparison) + + if self._match(TokenType.ON): + kwargs["on"] = self._parse_disjunction() + elif self._match(TokenType.USING): + kwargs["using"] = self._parse_using_identifiers() + elif ( + not method + and not (outer_apply or cross_apply) + and not isinstance(kwargs["this"], exp.Unnest) + and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY)) + ): + index = self._index + joins: t.Optional[list] = list(self._parse_joins()) + + if joins and self._match(TokenType.ON): + kwargs["on"] = self._parse_disjunction() + elif joins and self._match(TokenType.USING): + kwargs["using"] = self._parse_using_identifiers() + else: + joins = None + self._retreat(index) + + kwargs["this"].set("joins", joins if joins else None) + + kwargs["pivots"] = self._parse_pivots() + + comments = [ + c for token in (method, side, kind) if token for c in token.comments + ] + comments = (join_comments or []) + comments + + if ( + self.ADD_JOIN_ON_TRUE + and not kwargs.get("on") + and not kwargs.get("using") + and not kwargs.get("method") + and kwargs.get("kind") in (None, "INNER", "OUTER") + ): + kwargs["on"] = exp.true() + + return self.expression(exp.Join, comments=comments, **kwargs) + + def _parse_opclass(self) -> t.Optional[exp.Expression]: + this = self._parse_disjunction() + + if 
self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): + return this + + if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False): + return self.expression( + exp.Opclass, this=this, expression=self._parse_table_parts() + ) + + return this + + def _parse_index_params(self) -> exp.IndexParameters: + using = ( + self._parse_var(any_token=True) if self._match(TokenType.USING) else None + ) + + if self._match(TokenType.L_PAREN, advance=False): + columns = self._parse_wrapped_csv(self._parse_with_operator) + else: + columns = None + + include = ( + self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None + ) + partition_by = self._parse_partition_by() + with_storage = self._match(TokenType.WITH) and self._parse_wrapped_properties() + tablespace = ( + self._parse_var(any_token=True) + if self._match_text_seq("USING", "INDEX", "TABLESPACE") + else None + ) + where = self._parse_where() + + on = self._parse_field() if self._match(TokenType.ON) else None + + return self.expression( + exp.IndexParameters, + using=using, + columns=columns, + include=include, + partition_by=partition_by, + where=where, + with_storage=with_storage, + tablespace=tablespace, + on=on, + ) + + def _parse_index( + self, index: t.Optional[exp.Expression] = None, anonymous: bool = False + ) -> t.Optional[exp.Index]: + if index or anonymous: + unique = None + primary = None + amp = None + + self._match(TokenType.ON) + self._match(TokenType.TABLE) # hive + table = self._parse_table_parts(schema=True) + else: + unique = self._match(TokenType.UNIQUE) + primary = self._match_text_seq("PRIMARY") + amp = self._match_text_seq("AMP") + + if not self._match(TokenType.INDEX): + return None + + index = self._parse_id_var() + table = None + + params = self._parse_index_params() + + return self.expression( + exp.Index, + this=index, + table=table, + unique=unique, + primary=primary, + amp=amp, + params=params, + ) + + def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]: + hints: t.List[exp.Expression] = [] + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): + # https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16 + hints.append( + self.expression( + exp.WithTableHint, + expressions=self._parse_csv( + lambda: self._parse_function() + or self._parse_var(any_token=True) + ), + ) + ) + self._match_r_paren() + else: + # https://dev.mysql.com/doc/refman/8.0/en/index-hints.html + while self._match_set(self.TABLE_INDEX_HINT_TOKENS): + hint = exp.IndexTableHint(this=self._prev.text.upper()) + + self._match_set((TokenType.INDEX, TokenType.KEY)) + if self._match(TokenType.FOR): + hint.set("target", self._advance_any() and self._prev.text.upper()) + + hint.set("expressions", self._parse_wrapped_id_vars()) + hints.append(hint) + + return hints or None + + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: + return ( + (not schema and self._parse_function(optional_parens=False)) + or self._parse_id_var(any_token=False) + or self._parse_string_as_identifier() + or self._parse_placeholder() + ) + + def _parse_table_parts( + self, + schema: bool = False, + is_db_reference: bool = False, + wildcard: bool = False, + ) -> exp.Table: + catalog = None + db = None + table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) + + while self._match(TokenType.DOT): + if catalog: + # This allows nesting the table in arbitrarily many dot expressions if needed + table = self.expression( + exp.Dot, + this=table, + 
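+                    # e.g. for "cat.db.schema.tbl": "cat" and "db" end up in the
+                    # catalog/db args, while "schema.tbl" is kept as the nested
+                    # Dot expression built here.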
expression=self._parse_table_part(schema=schema), + ) + else: + catalog = db + db = table + # "" used for tsql FROM a..b case + table = self._parse_table_part(schema=schema) or "" + + if ( + wildcard + and self._is_connected() + and (isinstance(table, exp.Identifier) or not table) + and self._match(TokenType.STAR) + ): + if isinstance(table, exp.Identifier): + table.args["this"] += "*" + else: + table = exp.Identifier(this="*") + + # We bubble up comments from the Identifier to the Table + comments = table.pop_comments() if isinstance(table, exp.Expression) else None + + if is_db_reference: + catalog = db + db = table + table = None + + if not table and not is_db_reference: + self.raise_error(f"Expected table name but got {self._curr}") + if not db and is_db_reference: + self.raise_error(f"Expected database name but got {self._curr}") + + table = self.expression( + exp.Table, + comments=comments, + this=table, + db=db, + catalog=catalog, + ) + + changes = self._parse_changes() + if changes: + table.set("changes", changes) + + at_before = self._parse_historical_data() + if at_before: + table.set("when", at_before) + + pivots = self._parse_pivots() + if pivots: + table.set("pivots", pivots) + + return table + + def _parse_table( + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, + is_db_reference: bool = False, + parse_partition: bool = False, + consume_pipe: bool = False, + ) -> t.Optional[exp.Expression]: + stream = self._parse_stream() + if stream: + return stream + + lateral = self._parse_lateral() + if lateral: + return lateral + + unnest = self._parse_unnest() + if unnest: + return unnest + + values = self._parse_derived_table_values() + if values: + return values + + subquery = self._parse_select(table=True, consume_pipe=consume_pipe) + if subquery: + if not subquery.args.get("pivots"): + subquery.set("pivots", self._parse_pivots()) + return subquery + + bracket = parse_bracket and self._parse_bracket(None) + bracket = self.expression(exp.Table, this=bracket) if bracket else None + + rows_from = self._match_text_seq("ROWS", "FROM") and self._parse_wrapped_csv( + self._parse_table + ) + rows_from = ( + self.expression(exp.Table, rows_from=rows_from) if rows_from else None + ) + + only = self._match(TokenType.ONLY) + + this = t.cast( + exp.Expression, + bracket + or rows_from + or self._parse_bracket( + self._parse_table_parts(schema=schema, is_db_reference=is_db_reference) + ), + ) + + if only: + this.set("only", only) + + # Postgres supports a wildcard (table) suffix operator, which is a no-op in this context + self._match_text_seq("*") + + parse_partition = parse_partition or self.SUPPORTS_PARTITION_SELECTION + if parse_partition and self._match(TokenType.PARTITION, advance=False): + this.set("partition", self._parse_partition()) + + if schema: + return self._parse_schema(this=this) + + version = self._parse_version() + + if version: + this.set("version", version) + + if self.dialect.ALIAS_POST_TABLESAMPLE: + this.set("sample", self._parse_table_sample()) + + alias = self._parse_table_alias( + alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS + ) + if alias: + this.set("alias", alias) + + if self._match(TokenType.INDEXED_BY): + this.set("indexed", self._parse_table_parts()) + elif self._match_text_seq("NOT", "INDEXED"): + this.set("indexed", False) + + if isinstance(this, exp.Table) and self._match_text_seq("AT"): + return self.expression( + exp.AtIndex, + 
this=this.to_column(copy=False), + expression=self._parse_id_var(), + ) + + this.set("hints", self._parse_table_hints()) + + if not this.args.get("pivots"): + this.set("pivots", self._parse_pivots()) + + if not self.dialect.ALIAS_POST_TABLESAMPLE: + this.set("sample", self._parse_table_sample()) + + if joins: + for join in self._parse_joins(): + this.append("joins", join) + + if self._match_pair(TokenType.WITH, TokenType.ORDINALITY): + this.set("ordinality", True) + this.set("alias", self._parse_table_alias()) + + return this + + def _parse_version(self) -> t.Optional[exp.Version]: + if self._match(TokenType.TIMESTAMP_SNAPSHOT): + this = "TIMESTAMP" + elif self._match(TokenType.VERSION_SNAPSHOT): + this = "VERSION" + else: + return None + + if self._match_set((TokenType.FROM, TokenType.BETWEEN)): + kind = self._prev.text.upper() + start = self._parse_bitwise() + self._match_texts(("TO", "AND")) + end = self._parse_bitwise() + expression: t.Optional[exp.Expression] = self.expression( + exp.Tuple, expressions=[start, end] + ) + elif self._match_text_seq("CONTAINED", "IN"): + kind = "CONTAINED IN" + expression = self.expression( + exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise) + ) + elif self._match(TokenType.ALL): + kind = "ALL" + expression = None + else: + self._match_text_seq("AS", "OF") + kind = "AS OF" + expression = self._parse_type() + + return self.expression(exp.Version, this=this, expression=expression, kind=kind) + + def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]: + # https://docs.snowflake.com/en/sql-reference/constructs/at-before + index = self._index + historical_data = None + if self._match_texts(self.HISTORICAL_DATA_PREFIX): + this = self._prev.text.upper() + kind = ( + self._match(TokenType.L_PAREN) + and self._match_texts(self.HISTORICAL_DATA_KIND) + and self._prev.text.upper() + ) + expression = self._match(TokenType.FARROW) and self._parse_bitwise() + + if expression: + self._match_r_paren() + historical_data = self.expression( + exp.HistoricalData, this=this, kind=kind, expression=expression + ) + else: + self._retreat(index) + + return historical_data + + def _parse_changes(self) -> t.Optional[exp.Changes]: + if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"): + return None + + information = self._parse_var(any_token=True) + self._match_r_paren() + + return self.expression( + exp.Changes, + information=information, + at_before=self._parse_historical_data(), + end=self._parse_historical_data(), + ) + + def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: + if not self._match_pair(TokenType.UNNEST, TokenType.L_PAREN, advance=False): + return None + + self._advance() + + expressions = self._parse_wrapped_csv(self._parse_equality) + offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) + + alias = self._parse_table_alias() if with_alias else None + + if alias: + if self.dialect.UNNEST_COLUMN_ONLY: + if alias.args.get("columns"): + self.raise_error("Unexpected extra column alias in unnest.") + + alias.set("columns", [alias.this]) + alias.set("this", None) + + columns = alias.args.get("columns") or [] + if offset and len(expressions) < len(columns): + offset = columns.pop() + + if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET): + self._match(TokenType.ALIAS) + offset = self._parse_id_var( + any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS + ) or exp.to_identifier("offset") + + return self.expression( + exp.Unnest, expressions=expressions, alias=alias, 
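+            # Illustrative: "UNNEST(arr) WITH OFFSET AS pos" produces
+            # Unnest(expressions=[arr], offset=Identifier(pos)); when no alias
+            # follows WITH OFFSET, the identifier defaults to "offset".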
+            offset=offset
+        )
+
+    def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
+        is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
+        if not is_derived and not (
+            # ClickHouse's `FORMAT Values` is equivalent to `VALUES`
+            self._match_text_seq("VALUES")
+            or self._match_text_seq("FORMAT", "VALUES")
+        ):
+            return None
+
+        expressions = self._parse_csv(self._parse_value)
+        alias = self._parse_table_alias()
+
+        if is_derived:
+            self._match_r_paren()
+
+        return self.expression(
+            exp.Values,
+            expressions=expressions,
+            alias=alias or self._parse_table_alias(),
+        )
+
+    def _parse_table_sample(
+        self, as_modifier: bool = False
+    ) -> t.Optional[exp.TableSample]:
+        if not self._match(TokenType.TABLE_SAMPLE) and not (
+            as_modifier and self._match_text_seq("USING", "SAMPLE")
+        ):
+            return None
+
+        bucket_numerator = None
+        bucket_denominator = None
+        bucket_field = None
+        percent = None
+        size = None
+        seed = None
+
+        method = self._parse_var(tokens=(TokenType.ROW,), upper=True)
+        matched_l_paren = self._match(TokenType.L_PAREN)
+
+        if self.TABLESAMPLE_CSV:
+            num = None
+            expressions = self._parse_csv(self._parse_primary)
+        else:
+            expressions = None
+            num = (
+                self._parse_factor()
+                if self._match(TokenType.NUMBER, advance=False)
+                else self._parse_primary() or self._parse_placeholder()
+            )
+
+        if self._match_text_seq("BUCKET"):
+            bucket_numerator = self._parse_number()
+            self._match_text_seq("OUT", "OF")
+            bucket_denominator = self._parse_number()
+            self._match(TokenType.ON)
+            bucket_field = self._parse_field()
+        elif self._match_set((TokenType.PERCENT, TokenType.MOD)):
+            percent = num
+        elif (
+            self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT
+        ):
+            size = num
+        else:
+            percent = num
+
+        if matched_l_paren:
+            self._match_r_paren()
+
+        if self._match(TokenType.L_PAREN):
+            method = self._parse_var(upper=True)
+            seed = self._match(TokenType.COMMA) and self._parse_number()
+            self._match_r_paren()
+        elif self._match_texts(("SEED", "REPEATABLE")):
+            seed = self._parse_wrapped(self._parse_number)
+
+        if not method and self.DEFAULT_SAMPLING_METHOD:
+            method = exp.var(self.DEFAULT_SAMPLING_METHOD)
+
+        return self.expression(
+            exp.TableSample,
+            expressions=expressions,
+            method=method,
+            bucket_numerator=bucket_numerator,
+            bucket_denominator=bucket_denominator,
+            bucket_field=bucket_field,
+            percent=percent,
+            size=size,
+            seed=seed,
+        )
+
+    def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
+        return list(iter(self._parse_pivot, None)) or None
+
+    def _parse_joins(self) -> t.Iterator[exp.Join]:
+        return iter(self._parse_join, None)
+
+    def _parse_unpivot_columns(self) -> t.Optional[exp.UnpivotColumns]:
+        if not self._match(TokenType.INTO):
+            return None
+
+        return self.expression(
+            exp.UnpivotColumns,
+            this=self._match_text_seq("NAME") and self._parse_column(),
+            expressions=self._match_text_seq("VALUE")
+            and self._parse_csv(self._parse_column),
+        )
+
+    # https://duckdb.org/docs/sql/statements/pivot
+    def _parse_simplified_pivot(self, is_unpivot: t.Optional[bool] = None) -> exp.Pivot:
+        def _parse_on() -> t.Optional[exp.Expression]:
+            this = self._parse_bitwise()
+
+            if self._match(TokenType.IN):
+                # PIVOT ... ON col IN (row_val1, row_val2)
+                return self._parse_in(this)
+            if self._match(TokenType.ALIAS, advance=False):
+                # UNPIVOT ... ON (col1, col2, col3) AS row_val
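+                # (an aliased column tuple, e.g. DuckDB's
+                # "UNPIVOT monthly_sales ON (jan, feb, mar) AS q1")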
+                return self._parse_alias(this)
+
+            return this
+
+        this = self._parse_table()
+        expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
+        into = self._parse_unpivot_columns()
+        using = self._match(TokenType.USING) and self._parse_csv(
+            lambda: self._parse_alias(self._parse_column())
+        )
+        group = self._parse_group()
+
+        return self.expression(
+            exp.Pivot,
+            this=this,
+            expressions=expressions,
+            using=using,
+            group=group,
+            unpivot=is_unpivot,
+            into=into,
+        )
+
+    def _parse_pivot_in(self) -> exp.In:
+        def _parse_aliased_expression() -> t.Optional[exp.Expression]:
+            this = self._parse_select_or_expression()
+
+            self._match(TokenType.ALIAS)
+            alias = self._parse_bitwise()
+            if alias:
+                if isinstance(alias, exp.Column) and not alias.db:
+                    alias = alias.this
+                return self.expression(exp.PivotAlias, this=this, alias=alias)
+
+            return this
+
+        value = self._parse_column()
+
+        if not self._match(TokenType.IN):
+            self.raise_error("Expecting IN")
+
+        if self._match(TokenType.L_PAREN):
+            if self._match(TokenType.ANY):
+                exprs: t.List[exp.Expression] = ensure_list(
+                    exp.PivotAny(this=self._parse_order())
+                )
+            else:
+                exprs = self._parse_csv(_parse_aliased_expression)
+            self._match_r_paren()
+            return self.expression(exp.In, this=value, expressions=exprs)
+
+        return self.expression(exp.In, this=value, field=self._parse_id_var())
+
+    def _parse_pivot_aggregation(self) -> t.Optional[exp.Expression]:
+        func = self._parse_function()
+        if not func:
+            if self._prev and self._prev.token_type == TokenType.COMMA:
+                return None
+            self.raise_error("Expecting an aggregation function in PIVOT")
+
+        return self._parse_alias(func)
+
+    def _parse_pivot(self) -> t.Optional[exp.Pivot]:
+        index = self._index
+        include_nulls = None
+
+        if self._match(TokenType.PIVOT):
+            unpivot = False
+        elif self._match(TokenType.UNPIVOT):
+            unpivot = True
+
+            # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax
+            if self._match_text_seq("INCLUDE", "NULLS"):
+                include_nulls = True
+            elif self._match_text_seq("EXCLUDE", "NULLS"):
+                include_nulls = False
+        else:
+            return None
+
+        expressions = []
+
+        if not self._match(TokenType.L_PAREN):
+            self._retreat(index)
+            return None
+
+        if unpivot:
+            expressions = self._parse_csv(self._parse_column)
+        else:
+            expressions = self._parse_csv(self._parse_pivot_aggregation)
+
+        if not expressions:
+            self.raise_error("Failed to parse PIVOT's aggregation list")
+
+        if not self._match(TokenType.FOR):
+            self.raise_error("Expecting FOR")
+
+        fields = []
+        while True:
+            field = self._try_parse(self._parse_pivot_in)
+            if not field:
+                break
+            fields.append(field)
+
+        default_on_null = self._match_text_seq(
+            "DEFAULT", "ON", "NULL"
+        ) and self._parse_wrapped(self._parse_bitwise)
+
+        group = self._parse_group()
+
+        self._match_r_paren()
+
+        pivot = self.expression(
+            exp.Pivot,
+            expressions=expressions,
+            fields=fields,
+            unpivot=unpivot,
+            include_nulls=include_nulls,
+            default_on_null=default_on_null,
+            group=group,
+        )
+
+        if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
+            pivot.set("alias", self._parse_table_alias())
+
+        if not unpivot:
+            names = self._pivot_column_names(
+                t.cast(t.List[exp.Expression], expressions)
+            )
+
+            columns: t.List[exp.Expression] = []
+            all_fields = []
+            for pivot_field in pivot.fields:
+                pivot_field_expressions = pivot_field.expressions
+
+                # The `PivotAny` expression corresponds to `ANY ORDER BY <column>`; we can't infer in this case.
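+                # (e.g. "... FOR col IN (ANY ORDER BY col)" leaves the concrete
+                # pivot values unknown, so no output columns can be derived)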
+ if isinstance(seq_get(pivot_field_expressions, 0), exp.PivotAny): + continue + + all_fields.append( + [ + fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name + for fld in pivot_field_expressions + ] + ) + + if all_fields: + if names: + all_fields.append(names) + + # Generate all possible combinations of the pivot columns + # e.g PIVOT(sum(...) as total FOR year IN (2000, 2010) FOR country IN ('NL', 'US')) + # generates the product between [[2000, 2010], ['NL', 'US'], ['total']] + for fld_parts_tuple in itertools.product(*all_fields): + fld_parts = list(fld_parts_tuple) + + if names and self.PREFIXED_PIVOT_COLUMNS: + # Move the "name" to the front of the list + fld_parts.insert(0, fld_parts.pop(-1)) + + columns.append(exp.to_identifier("_".join(fld_parts))) + + pivot.set("columns", columns) + + return pivot + + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: + return [agg.alias for agg in aggregations if agg.alias] + + def _parse_prewhere( + self, skip_where_token: bool = False + ) -> t.Optional[exp.PreWhere]: + if not skip_where_token and not self._match(TokenType.PREWHERE): + return None + + return self.expression( + exp.PreWhere, comments=self._prev_comments, this=self._parse_disjunction() + ) + + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]: + if not skip_where_token and not self._match(TokenType.WHERE): + return None + + return self.expression( + exp.Where, comments=self._prev_comments, this=self._parse_disjunction() + ) + + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]: + if not skip_group_by_token and not self._match(TokenType.GROUP_BY): + return None + comments = self._prev_comments + + elements: t.Dict[str, t.Any] = defaultdict(list) + + if self._match(TokenType.ALL): + elements["all"] = True + elif self._match(TokenType.DISTINCT): + elements["all"] = False + + if self._match_set(self.QUERY_MODIFIER_TOKENS, advance=False): + return self.expression(exp.Group, comments=comments, **elements) # type: ignore + + while True: + index = self._index + + elements["expressions"].extend( + self._parse_csv( + lambda: None + if self._match_set( + (TokenType.CUBE, TokenType.ROLLUP), advance=False + ) + else self._parse_disjunction() + ) + ) + + before_with_index = self._index + with_prefix = self._match(TokenType.WITH) + + if cube_or_rollup := self._parse_cube_or_rollup(with_prefix=with_prefix): + key = "rollup" if isinstance(cube_or_rollup, exp.Rollup) else "cube" + elements[key].append(cube_or_rollup) + elif grouping_sets := self._parse_grouping_sets(): + elements["grouping_sets"].append(grouping_sets) + elif self._match_text_seq("TOTALS"): + elements["totals"] = True # type: ignore + + if before_with_index <= self._index <= before_with_index + 1: + self._retreat(before_with_index) + break + + if index == self._index: + break + + return self.expression(exp.Group, comments=comments, **elements) # type: ignore + + def _parse_cube_or_rollup( + self, with_prefix: bool = False + ) -> t.Optional[exp.Cube | exp.Rollup]: + if self._match(TokenType.CUBE): + kind: t.Type[exp.Cube | exp.Rollup] = exp.Cube + elif self._match(TokenType.ROLLUP): + kind = exp.Rollup + else: + return None + + return self.expression( + kind, + expressions=[] + if with_prefix + else self._parse_wrapped_csv(self._parse_bitwise), + ) + + def _parse_grouping_sets(self) -> t.Optional[exp.GroupingSets]: + if self._match(TokenType.GROUPING_SETS): + return self.expression( + exp.GroupingSets, + 
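+                # e.g. "GROUP BY GROUPING SETS ((a), (a, b), ())"; each wrapped
+                # element may itself be a CUBE/ROLLUP spec, which is why
+                # _parse_grouping_set below also tries those parsers.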
expressions=self._parse_wrapped_csv(self._parse_grouping_set), + ) + return None + + def _parse_grouping_set(self) -> t.Optional[exp.Expression]: + return ( + self._parse_grouping_sets() + or self._parse_cube_or_rollup() + or self._parse_bitwise() + ) + + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]: + if not skip_having_token and not self._match(TokenType.HAVING): + return None + return self.expression( + exp.Having, comments=self._prev_comments, this=self._parse_disjunction() + ) + + def _parse_qualify(self) -> t.Optional[exp.Qualify]: + if not self._match(TokenType.QUALIFY): + return None + return self.expression(exp.Qualify, this=self._parse_disjunction()) + + def _parse_connect_with_prior(self) -> t.Optional[exp.Expression]: + self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression( + exp.Prior, this=self._parse_bitwise() + ) + connect = self._parse_disjunction() + self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR") + return connect + + def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]: + if skip_start_token: + start = None + elif self._match(TokenType.START_WITH): + start = self._parse_disjunction() + else: + return None + + self._match(TokenType.CONNECT_BY) + nocycle = self._match_text_seq("NOCYCLE") + connect = self._parse_connect_with_prior() + + if not start and self._match(TokenType.START_WITH): + start = self._parse_disjunction() + + return self.expression( + exp.Connect, start=start, connect=connect, nocycle=nocycle + ) + + def _parse_name_as_expression(self) -> t.Optional[exp.Expression]: + this = self._parse_id_var(any_token=True) + if self._match(TokenType.ALIAS): + this = self.expression( + exp.Alias, alias=this, this=self._parse_disjunction() + ) + return this + + def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]: + if self._match_text_seq("INTERPOLATE"): + return self._parse_wrapped_csv(self._parse_name_as_expression) + return None + + def _parse_order( + self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False + ) -> t.Optional[exp.Expression]: + siblings = None + if not skip_order_token and not self._match(TokenType.ORDER_BY): + if not self._match(TokenType.ORDER_SIBLINGS_BY): + return this + + siblings = True + + return self.expression( + exp.Order, + comments=self._prev_comments, + this=this, + expressions=self._parse_csv(self._parse_ordered), + siblings=siblings, + ) + + def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]: + if not self._match(token): + return None + return self.expression( + exp_class, expressions=self._parse_csv(self._parse_ordered) + ) + + def _parse_ordered( + self, parse_method: t.Optional[t.Callable] = None + ) -> t.Optional[exp.Ordered]: + this = parse_method() if parse_method else self._parse_disjunction() + if not this: + return None + + if this.name.upper() == "ALL" and self.dialect.SUPPORTS_ORDER_BY_ALL: + this = exp.var("ALL") + + asc = self._match(TokenType.ASC) + desc = self._match(TokenType.DESC) or (asc and False) + + is_nulls_first = self._match_text_seq("NULLS", "FIRST") + is_nulls_last = self._match_text_seq("NULLS", "LAST") + + nulls_first = is_nulls_first or False + explicitly_null_ordered = is_nulls_first or is_nulls_last + + if ( + not explicitly_null_ordered + and ( + (not desc and self.dialect.NULL_ORDERING == "nulls_are_small") + or (desc and self.dialect.NULL_ORDERING != "nulls_are_small") + ) + and self.dialect.NULL_ORDERING != "nulls_are_last" + ): + nulls_first = True + + if 
self._match_text_seq("WITH", "FILL"): + with_fill = self.expression( + exp.WithFill, + from_=self._match(TokenType.FROM) and self._parse_bitwise(), + to=self._match_text_seq("TO") and self._parse_bitwise(), + step=self._match_text_seq("STEP") and self._parse_bitwise(), + interpolate=self._parse_interpolate(), + ) + else: + with_fill = None + + return self.expression( + exp.Ordered, + this=this, + desc=desc, + nulls_first=nulls_first, + with_fill=with_fill, + ) + + def _parse_limit_options(self) -> t.Optional[exp.LimitOptions]: + percent = self._match_set((TokenType.PERCENT, TokenType.MOD)) + rows = self._match_set((TokenType.ROW, TokenType.ROWS)) + self._match_text_seq("ONLY") + with_ties = self._match_text_seq("WITH", "TIES") + + if not (percent or rows or with_ties): + return None + + return self.expression( + exp.LimitOptions, percent=percent, rows=rows, with_ties=with_ties + ) + + def _parse_limit( + self, + this: t.Optional[exp.Expression] = None, + top: bool = False, + skip_limit_token: bool = False, + ) -> t.Optional[exp.Expression]: + if skip_limit_token or self._match(TokenType.TOP if top else TokenType.LIMIT): + comments = self._prev_comments + if top: + limit_paren = self._match(TokenType.L_PAREN) + expression = self._parse_term() if limit_paren else self._parse_number() + + if limit_paren: + self._match_r_paren() + + else: + # Parsing LIMIT x% (i.e x PERCENT) as a term leads to an error, since + # we try to build an exp.Mod expr. For that matter, we backtrack and instead + # consume the factor plus parse the percentage separately + index = self._index + expression = self._try_parse(self._parse_term) + if isinstance(expression, exp.Mod): + self._retreat(index) + expression = self._parse_factor() + elif not expression: + expression = self._parse_factor() + limit_options = self._parse_limit_options() + + if self._match(TokenType.COMMA): + offset = expression + expression = self._parse_term() + else: + offset = None + + limit_exp = self.expression( + exp.Limit, + this=this, + expression=expression, + offset=offset, + comments=comments, + limit_options=limit_options, + expressions=self._parse_limit_by(), + ) + + return limit_exp + + if self._match(TokenType.FETCH): + direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) + direction = self._prev.text.upper() if direction else "FIRST" + + count = self._parse_field(tokens=self.FETCH_TOKENS) + + return self.expression( + exp.Fetch, + direction=direction, + count=count, + limit_options=self._parse_limit_options(), + ) + + return this + + def _parse_offset( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + if not self._match(TokenType.OFFSET): + return this + + count = self._parse_term() + self._match_set((TokenType.ROW, TokenType.ROWS)) + + return self.expression( + exp.Offset, this=this, expression=count, expressions=self._parse_limit_by() + ) + + def _can_parse_limit_or_offset(self) -> bool: + if not self._match_set(self.AMBIGUOUS_ALIAS_TOKENS, advance=False): + return False + + index = self._index + result = bool( + self._try_parse(self._parse_limit, retreat=True) + or self._try_parse(self._parse_offset, retreat=True) + ) + self._retreat(index) + return result + + def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]: + return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise) + + def _parse_locks(self) -> t.List[exp.Lock]: + locks = [] + while True: + update, key = None, None + if self._match_text_seq("FOR", "UPDATE"): + update = True + elif 
self._match_text_seq("FOR", "SHARE") or self._match_text_seq( + "LOCK", "IN", "SHARE", "MODE" + ): + update = False + elif self._match_text_seq("FOR", "KEY", "SHARE"): + update, key = False, True + elif self._match_text_seq("FOR", "NO", "KEY", "UPDATE"): + update, key = True, True + else: + break + + expressions = None + if self._match_text_seq("OF"): + expressions = self._parse_csv(lambda: self._parse_table(schema=True)) + + wait: t.Optional[bool | exp.Expression] = None + if self._match_text_seq("NOWAIT"): + wait = True + elif self._match_text_seq("WAIT"): + wait = self._parse_primary() + elif self._match_text_seq("SKIP", "LOCKED"): + wait = False + + locks.append( + self.expression( + exp.Lock, update=update, expressions=expressions, wait=wait, key=key + ) + ) + + return locks + + def parse_set_operation( + self, this: t.Optional[exp.Expression], consume_pipe: bool = False + ) -> t.Optional[exp.Expression]: + start = self._index + _, side_token, kind_token = self._parse_join_parts() + + side = side_token.text if side_token else None + kind = kind_token.text if kind_token else None + + if not self._match_set(self.SET_OPERATIONS): + self._retreat(start) + return None + + token_type = self._prev.token_type + + if token_type == TokenType.UNION: + operation: t.Type[exp.SetOperation] = exp.Union + elif token_type == TokenType.EXCEPT: + operation = exp.Except + else: + operation = exp.Intersect + + comments = self._prev.comments + + if self._match(TokenType.DISTINCT): + distinct: t.Optional[bool] = True + elif self._match(TokenType.ALL): + distinct = False + else: + distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[operation] + if distinct is None: + self.raise_error(f"Expected DISTINCT or ALL for {operation.__name__}") + + by_name = self._match_text_seq("BY", "NAME") or self._match_text_seq( + "STRICT", "CORRESPONDING" + ) + if self._match_text_seq("CORRESPONDING"): + by_name = True + if not side and not kind: + kind = "INNER" + + on_column_list = None + if by_name and self._match_texts(("ON", "BY")): + on_column_list = self._parse_wrapped_csv(self._parse_column) + + expression = self._parse_select( + nested=True, parse_set_operation=False, consume_pipe=consume_pipe + ) + + return self.expression( + operation, + comments=comments, + this=this, + distinct=distinct, + by_name=by_name, + expression=expression, + side=side, + kind=kind, + on=on_column_list, + ) + + def _parse_set_operations( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + while this: + setop = self.parse_set_operation(this) + if not setop: + break + this = setop + + if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP: + expression = this.expression + + if expression: + for arg in self.SET_OP_MODIFIERS: + expr = expression.args.get(arg) + if expr: + this.set(arg, expr.pop()) + + return this + + def _parse_expression(self) -> t.Optional[exp.Expression]: + return self._parse_alias(self._parse_assignment()) + + def _parse_assignment(self) -> t.Optional[exp.Expression]: + this = self._parse_disjunction() + if not this and self._next and self._next.token_type in self.ASSIGNMENT: + # This allows us to parse := + this = exp.column( + t.cast(str, self._advance_any(ignore_reserved=True) and self._prev.text) + ) + + while self._match_set(self.ASSIGNMENT): + if isinstance(this, exp.Column) and len(this.parts) == 1: + this = this.this + + this = self.expression( + self.ASSIGNMENT[self._prev.token_type], + this=this, + comments=self._prev_comments, + expression=self._parse_assignment(), + 
) + + return this + + def _parse_disjunction(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_conjunction, self.DISJUNCTION) + + def _parse_conjunction(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_equality, self.CONJUNCTION) + + def _parse_equality(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_comparison, self.EQUALITY) + + def _parse_comparison(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_range, self.COMPARISON) + + def _parse_range( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + this = this or self._parse_bitwise() + negate = self._match(TokenType.NOT) + + if self._match_set(self.RANGE_PARSERS): + expression = self.RANGE_PARSERS[self._prev.token_type](self, this) + if not expression: + return this + + this = expression + elif self._match(TokenType.ISNULL) or (negate and self._match(TokenType.NULL)): + this = self.expression(exp.Is, this=this, expression=exp.Null()) + + # Postgres supports ISNULL and NOTNULL for conditions. + # https://blog.andreiavram.ro/postgresql-null-composite-type/ + if self._match(TokenType.NOTNULL): + this = self.expression(exp.Is, this=this, expression=exp.Null()) + this = self.expression(exp.Not, this=this) + + if negate: + this = self._negate_range(this) + + if self._match(TokenType.IS): + this = self._parse_is(this) + + return this + + def _negate_range( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + if not this: + return this + + return self.expression(exp.Not, this=this) + + def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + index = self._index - 1 + negate = self._match(TokenType.NOT) + + if self._match_text_seq("DISTINCT", "FROM"): + klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ + return self.expression(klass, this=this, expression=self._parse_bitwise()) + + if self._match(TokenType.JSON): + kind = ( + self._match_texts(self.IS_JSON_PREDICATE_KIND) + and self._prev.text.upper() + ) + + if self._match_text_seq("WITH"): + _with = True + elif self._match_text_seq("WITHOUT"): + _with = False + else: + _with = None + + unique = self._match(TokenType.UNIQUE) + self._match_text_seq("KEYS") + expression: t.Optional[exp.Expression] = self.expression( + exp.JSON, + this=kind, + with_=_with, + unique=unique, + ) + else: + expression = self._parse_null() or self._parse_bitwise() + if not expression: + self._retreat(index) + return None + + this = self.expression(exp.Is, this=this, expression=expression) + this = self.expression(exp.Not, this=this) if negate else this + return self._parse_column_ops(this) + + def _parse_in( + self, this: t.Optional[exp.Expression], alias: bool = False + ) -> exp.In: + unnest = self._parse_unnest(with_alias=False) + if unnest: + this = self.expression(exp.In, this=this, unnest=unnest) + elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)): + matched_l_paren = self._prev.token_type == TokenType.L_PAREN + expressions = self._parse_csv( + lambda: self._parse_select_or_expression(alias=alias) + ) + + if len(expressions) == 1 and isinstance(query := expressions[0], exp.Query): + this = self.expression( + exp.In, + this=this, + query=self._parse_query_modifiers(query).subquery(copy=False), + ) + else: + this = self.expression(exp.In, this=this, expressions=expressions) + + if matched_l_paren: + self._match_r_paren(this) + elif not self._match(TokenType.R_BRACKET, expression=this): + 
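+                # the IN list was opened with "[" rather than "(", so the
+                # matching "]" is required here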
self.raise_error("Expecting ]") + else: + this = self.expression(exp.In, this=this, field=self._parse_column()) + + return this + + def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between: + symmetric = None + if self._match_text_seq("SYMMETRIC"): + symmetric = True + elif self._match_text_seq("ASYMMETRIC"): + symmetric = False + + low = self._parse_bitwise() + self._match(TokenType.AND) + high = self._parse_bitwise() + + return self.expression( + exp.Between, + this=this, + low=low, + high=high, + symmetric=symmetric, + ) + + def _parse_escape( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if not self._match(TokenType.ESCAPE): + return this + return self.expression( + exp.Escape, this=this, expression=self._parse_string() or self._parse_null() + ) + + def _parse_interval( + self, match_interval: bool = True + ) -> t.Optional[exp.Add | exp.Interval]: + index = self._index + + if not self._match(TokenType.INTERVAL) and match_interval: + return None + + if self._match(TokenType.STRING, advance=False): + this = self._parse_primary() + else: + this = self._parse_term() + + if not this or ( + isinstance(this, exp.Column) + and not this.table + and not this.this.quoted + and self._curr + and self._curr.text.upper() not in self.dialect.VALID_INTERVAL_UNITS + ): + self._retreat(index) + return None + + # handle day-time format interval span with omitted units: + # INTERVAL ' hh[:][mm[:ss[.ff]]]' + interval_span_units_omitted = None + if ( + this + and this.is_string + and self.SUPPORTS_OMITTED_INTERVAL_SPAN_UNIT + and exp.INTERVAL_DAY_TIME_RE.match(this.name) + ): + index = self._index + + # Var "TO" Var + first_unit = self._parse_var(any_token=True, upper=True) + second_unit = None + if first_unit and self._match_text_seq("TO"): + second_unit = self._parse_var(any_token=True, upper=True) + + interval_span_units_omitted = not (first_unit and second_unit) + + self._retreat(index) + + unit = ( + None + if interval_span_units_omitted + else ( + self._parse_function() + or ( + not self._match(TokenType.ALIAS, advance=False) + and self._parse_var(any_token=True, upper=True) + ) + ) + ) + + # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse + # each INTERVAL expression into this canonical form so it's easy to transpile + if this and this.is_number: + this = exp.Literal.string(this.to_py()) + elif this and this.is_string: + parts = exp.INTERVAL_STRING_RE.findall(this.name) + if parts and unit: + # Unconsume the eagerly-parsed unit, since the real unit was part of the string + unit = None + self._retreat(self._index - 1) + + if len(parts) == 1: + this = exp.Literal.string(parts[0][0]) + unit = self.expression(exp.Var, this=parts[0][1].upper()) + + if self.INTERVAL_SPANS and self._match_text_seq("TO"): + unit = self.expression( + exp.IntervalSpan, + this=unit, + expression=self._parse_var(any_token=True, upper=True), + ) + + interval = self.expression(exp.Interval, this=this, unit=unit) + + index = self._index + self._match(TokenType.PLUS) + + # Convert INTERVAL 'val_1' unit_1 [+] ... 
+        # [+] 'val_n' unit_n into a sum of intervals
+        if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+            return self.expression(
+                exp.Add,
+                this=interval,
+                expression=self._parse_interval(match_interval=False),
+            )
+
+        self._retreat(index)
+        return interval
+
+    def _parse_bitwise(self) -> t.Optional[exp.Expression]:
+        this = self._parse_term()
+
+        while True:
+            if self._match_set(self.BITWISE):
+                this = self.expression(
+                    self.BITWISE[self._prev.token_type],
+                    this=this,
+                    expression=self._parse_term(),
+                )
+            elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE):
+                this = self.expression(
+                    exp.DPipe,
+                    this=this,
+                    expression=self._parse_term(),
+                    safe=not self.dialect.STRICT_STRING_CONCAT,
+                )
+            elif self._match(TokenType.DQMARK):
+                this = self.expression(
+                    exp.Coalesce, this=this, expressions=ensure_list(self._parse_term())
+                )
+            elif self._match_pair(TokenType.LT, TokenType.LT):
+                this = self.expression(
+                    exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+                )
+            elif self._match_pair(TokenType.GT, TokenType.GT):
+                this = self.expression(
+                    exp.BitwiseRightShift, this=this, expression=self._parse_term()
+                )
+            else:
+                break
+
+        return this
+
+    def _parse_term(self) -> t.Optional[exp.Expression]:
+        this = self._parse_factor()
+
+        while self._match_set(self.TERM):
+            klass = self.TERM[self._prev.token_type]
+            comments = self._prev_comments
+            expression = self._parse_factor()
+
+            this = self.expression(
+                klass, this=this, comments=comments, expression=expression
+            )
+
+            if isinstance(this, exp.Collate):
+                expr = this.expression
+
+                # Preserve collations such as pg_catalog."default" (Postgres) as columns, otherwise
+                # fallback to Identifier / Var
+                if isinstance(expr, exp.Column) and len(expr.parts) == 1:
+                    ident = expr.this
+                    if isinstance(ident, exp.Identifier):
+                        this.set(
+                            "expression", ident if ident.quoted else exp.var(ident.name)
+                        )
+
+        return this
+
+    def _parse_factor(self) -> t.Optional[exp.Expression]:
+        parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
+        this = self._parse_at_time_zone(parse_method())
+
+        while self._match_set(self.FACTOR):
+            klass = self.FACTOR[self._prev.token_type]
+            comments = self._prev_comments
+            expression = parse_method()
+
+            if not expression and klass is exp.IntDiv and self._prev.text.isalpha():
+                self._retreat(self._index - 1)
+                return this
+
+            this = self.expression(
+                klass, this=this, comments=comments, expression=expression
+            )
+
+            if isinstance(this, exp.Div):
+                this.set("typed", self.dialect.TYPED_DIVISION)
+                this.set("safe", self.dialect.SAFE_DIVISION)
+
+        return this
+
+    def _parse_exponent(self) -> t.Optional[exp.Expression]:
+        return self._parse_tokens(self._parse_unary, self.EXPONENT)
+
+    def _parse_unary(self) -> t.Optional[exp.Expression]:
+        if self._match_set(self.UNARY_PARSERS):
+            return self.UNARY_PARSERS[self._prev.token_type](self)
+        return self._parse_type()
+
+    def _parse_type(
+        self, parse_interval: bool = True, fallback_to_identifier: bool = False
+    ) -> t.Optional[exp.Expression]:
+        interval = parse_interval and self._parse_interval()
+        if interval:
+            return self._parse_column_ops(interval)
+
+        index = self._index
+        data_type = self._parse_types(check_func=True, allow_identifiers=False)
+
+        # parse_types() returns a Cast if we parsed BQ's inline constructor <type>(<values>) e.g.
+        # STRUCT<a INT, b STRING>(1, 'foo'), which is canonicalized to CAST(<values> AS <type>)
+        if isinstance(data_type, exp.Cast):
+            # This constructor can contain ops directly after it, for instance struct unnesting:
+            # STRUCT<a INT, b STRING>(1, 'foo').* --> CAST(STRUCT(1, 'foo') AS STRUCT<a INT, b STRING>).*
+            return self._parse_column_ops(data_type)
+
+        if data_type:
+            index2 = self._index
+            this = self._parse_primary()
+
+            if isinstance(this, exp.Literal):
+                parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
+                if parser:
+                    return parser(self, this, data_type)
+
+                return self.expression(exp.Cast, this=this, to=data_type)
+
+            # More than one consumed token means the parser itself parsed e.g.
+            # a DECIMAL(38, 0) precision/scale, so reparse the whole span as
+            # column ops rather than as a bare type
+            if index2 - index > 1:
+                self._retreat(index2)
+                return self._parse_column_ops(data_type)
+
+            self._retreat(index)
+
+        if fallback_to_identifier:
+            return self._parse_id_var()
+
+        this = self._parse_column()
+        return this and self._parse_column_ops(this)
+
+    def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
+        this = self._parse_type()
+        if not this:
+            return None
+
+        if isinstance(this, exp.Column) and not this.table:
+            this = exp.var(this.name.upper())
+
+        return self.expression(
+            exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
+        )
+
+    def _parse_user_defined_type(
+        self, identifier: exp.Identifier
+    ) -> t.Optional[exp.Expression]:
+        type_name = identifier.name
+
+        while self._match(TokenType.DOT):
+            type_name = f"{type_name}.{self._advance_any() and self._prev.text}"
+
+        return exp.DataType.build(type_name, dialect=self.dialect, udt=True)
+
+    def _parse_types(
+        self,
+        check_func: bool = False,
+        schema: bool = False,
+        allow_identifiers: bool = True,
+    ) -> t.Optional[exp.Expression]:
+        index = self._index
+
+        this: t.Optional[exp.Expression] = None
+        prefix = self._match_text_seq("SYSUDTLIB", ".")
+
+        if self._match_set(self.TYPE_TOKENS):
+            type_token = self._prev.token_type
+        else:
+            type_token = None
+            identifier = allow_identifiers and self._parse_id_var(
+                any_token=False, tokens=(TokenType.VAR,)
+            )
+            if isinstance(identifier, exp.Identifier):
+                try:
+                    tokens = self.dialect.tokenize(identifier.name)
+                except TokenError:
+                    tokens = None
+
+                if (
+                    tokens
+                    and len(tokens) == 1
+                    and tokens[0].token_type in self.TYPE_TOKENS
+                ):
+                    type_token = tokens[0].token_type
+                elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
+                    this = self._parse_user_defined_type(identifier)
+                else:
+                    self._retreat(self._index - 1)
+                    return None
+            else:
+                return None
+
+        if type_token == TokenType.PSEUDO_TYPE:
+            return self.expression(exp.PseudoType, this=self._prev.text.upper())
+
+        if type_token == TokenType.OBJECT_IDENTIFIER:
+            return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper())
+
+        # https://materialize.com/docs/sql/types/map/
+        if type_token == TokenType.MAP and self._match(TokenType.L_BRACKET):
+            key_type = self._parse_types(
+                check_func=check_func,
+                schema=schema,
+                allow_identifiers=allow_identifiers,
+            )
+            if not self._match(TokenType.FARROW):
+                self._retreat(index)
+                return None
+
+            value_type = self._parse_types(
+                check_func=check_func,
+                schema=schema,
+                allow_identifiers=allow_identifiers,
+            )
+            if not self._match(TokenType.R_BRACKET):
+                self._retreat(index)
+                return None
+
+            return exp.DataType(
+                this=exp.DataType.Type.MAP,
+                expressions=[key_type, value_type],
+                nested=True,
+                prefix=prefix,
+            )
+
+        nested = type_token in self.NESTED_TYPE_TOKENS
+        is_struct = type_token in self.STRUCT_TYPE_TOKENS
+        is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
+        expressions = None
+        maybe_func = False
+
+        if self._match(TokenType.L_PAREN):
+            if is_struct:
+                expressions = self._parse_csv(
+                    lambda: self._parse_struct_types(type_required=True)
+                )
+            elif nested:
+                expressions = self._parse_csv(
+                    lambda: self._parse_types(
+                        check_func=check_func,
+                        schema=schema,
+                        allow_identifiers=allow_identifiers,
+                    )
+                )
+                if type_token == TokenType.NULLABLE and len(expressions) == 1:
+                    this = expressions[0]
+                    this.set("nullable",
True) + self._match_r_paren() + return this + elif type_token in self.ENUM_TYPE_TOKENS: + expressions = self._parse_csv(self._parse_equality) + elif is_aggregate: + func_or_ident = self._parse_function( + anonymous=True + ) or self._parse_id_var( + any_token=False, tokens=(TokenType.VAR, TokenType.ANY) + ) + if not func_or_ident: + return None + expressions = [func_or_ident] + if self._match(TokenType.COMMA): + expressions.extend( + self._parse_csv( + lambda: self._parse_types( + check_func=check_func, + schema=schema, + allow_identifiers=allow_identifiers, + ) + ) + ) + else: + expressions = self._parse_csv(self._parse_type_size) + + # https://docs.snowflake.com/en/sql-reference/data-types-vector + if type_token == TokenType.VECTOR and len(expressions) == 2: + expressions = self._parse_vector_expressions(expressions) + + if not self._match(TokenType.R_PAREN): + self._retreat(index) + return None + + maybe_func = True + + values: t.Optional[t.List[exp.Expression]] = None + + if nested and self._match(TokenType.LT): + if is_struct: + expressions = self._parse_csv( + lambda: self._parse_struct_types(type_required=True) + ) + else: + expressions = self._parse_csv( + lambda: self._parse_types( + check_func=check_func, + schema=schema, + allow_identifiers=allow_identifiers, + ) + ) + + if not self._match(TokenType.GT): + self.raise_error("Expecting >") + + if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)): + values = self._parse_csv(self._parse_disjunction) + if not values and is_struct: + values = None + self._retreat(self._index - 1) + else: + self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN)) + + if type_token in self.TIMESTAMPS: + if self._match_text_seq("WITH", "TIME", "ZONE"): + maybe_func = False + tz_type = ( + exp.DataType.Type.TIMETZ + if type_token in self.TIMES + else exp.DataType.Type.TIMESTAMPTZ + ) + this = exp.DataType(this=tz_type, expressions=expressions) + elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): + maybe_func = False + this = exp.DataType( + this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions + ) + elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): + maybe_func = False + elif type_token == TokenType.INTERVAL: + unit = self._parse_var(upper=True) + if unit: + if self._match_text_seq("TO"): + unit = exp.IntervalSpan( + this=unit, expression=self._parse_var(upper=True) + ) + + this = self.expression( + exp.DataType, this=self.expression(exp.Interval, unit=unit) + ) + else: + this = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) + elif type_token == TokenType.VOID: + this = exp.DataType(this=exp.DataType.Type.NULL) + + if maybe_func and check_func: + index2 = self._index + peek = self._parse_string() + + if not peek: + self._retreat(index) + return None + + self._retreat(index2) + + if not this: + if self._match_text_seq("UNSIGNED"): + unsigned_type_token = self.SIGNED_TO_UNSIGNED_TYPE_TOKEN.get(type_token) + if not unsigned_type_token: + self.raise_error(f"Cannot convert {type_token.value} to unsigned.") + + type_token = unsigned_type_token or type_token + + # NULLABLE without parentheses can be a column (Presto/Trino) + if type_token == TokenType.NULLABLE and not expressions: + self._retreat(index) + return None + + this = exp.DataType( + this=exp.DataType.Type[type_token.value], + expressions=expressions, + nested=nested, + prefix=prefix, + ) + + # Empty arrays/structs are allowed + if values is not None: + cls = exp.Struct if is_struct else exp.Array + this = exp.cast(cls(expressions=values), this, copy=False) + 
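+        # NOTE (editorial, not upstream text): `values` carries constructor values
+        # such as the [1, 2] in ARRAY<INT>[1, 2], canonicalized above to
+        # CAST([1, 2] AS ARRAY<INT>), whereas `expressions` carries type parameters
+        # such as the 38, 0 in DECIMAL(38, 0).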
+        elif expressions:
+            this.set("expressions", expressions)
+
+        # https://materialize.com/docs/sql/types/list/#type-name
+        while self._match(TokenType.LIST):
+            this = exp.DataType(
+                this=exp.DataType.Type.LIST, expressions=[this], nested=True
+            )
+
+        index = self._index
+
+        # Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3]
+        matched_array = self._match(TokenType.ARRAY)
+
+        while self._curr:
+            datatype_token = self._prev.token_type
+            matched_l_bracket = self._match(TokenType.L_BRACKET)
+
+            if (not matched_l_bracket and not matched_array) or (
+                datatype_token == TokenType.ARRAY and self._match(TokenType.R_BRACKET)
+            ):
+                # Postgres allows casting empty arrays such as ARRAY[]::INT[],
+                # not to be confused with the fixed size array parsing
+                break
+
+            matched_array = False
+            values = self._parse_csv(self._parse_disjunction) or None
+            if (
+                values
+                and not schema
+                and (
+                    not self.dialect.SUPPORTS_FIXED_SIZE_ARRAYS
+                    or datatype_token == TokenType.ARRAY
+                    or not self._match(TokenType.R_BRACKET, advance=False)
+                )
+            ):
+                # Retreating here means that we should not parse the following values as part of the data type, e.g. in DuckDB
+                # ARRAY[1] should retreat and instead be parsed into exp.Array in contrast to INT[x][y] which denotes a fixed-size array data type
+                self._retreat(index)
+                break
+
+            this = exp.DataType(
+                this=exp.DataType.Type.ARRAY,
+                expressions=[this],
+                values=values,
+                nested=True,
+            )
+            self._match(TokenType.R_BRACKET)
+
+        if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type):
+            converter = self.TYPE_CONVERTERS.get(this.this)
+            if converter:
+                this = converter(t.cast(exp.DataType, this))
+
+        return this
+
+    def _parse_vector_expressions(
+        self, expressions: t.List[exp.Expression]
+    ) -> t.List[exp.Expression]:
+        return [
+            exp.DataType.build(expressions[0].name, dialect=self.dialect),
+            *expressions[1:],
+        ]
+
+    def _parse_struct_types(
+        self, type_required: bool = False
+    ) -> t.Optional[exp.Expression]:
+        index = self._index
+
+        if (
+            self._curr
+            and self._next
+            and self._curr.token_type in self.TYPE_TOKENS
+            and self._next.token_type in self.TYPE_TOKENS
+        ):
+            # Takes care of special cases like `STRUCT<list ARRAY<INT>>` where the identifier is also a
+            # type token.
Without this, the list will be parsed as a type and we'll eventually crash + this = self._parse_id_var() + else: + this = ( + self._parse_type(parse_interval=False, fallback_to_identifier=True) + or self._parse_id_var() + ) + + self._match(TokenType.COLON) + + if ( + type_required + and not isinstance(this, exp.DataType) + and not self._match_set(self.TYPE_TOKENS, advance=False) + ): + self._retreat(index) + return self._parse_types() + + return self._parse_column_def(this) + + def _parse_at_time_zone( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if not self._match_text_seq("AT", "TIME", "ZONE"): + return this + return self._parse_at_time_zone( + self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) + ) + + def _parse_column(self) -> t.Optional[exp.Expression]: + this = self._parse_column_reference() + column = self._parse_column_ops(this) if this else self._parse_bracket(this) + + if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column: + column.set("join_mark", self._match(TokenType.JOIN_MARKER)) + + return column + + def _parse_column_reference(self) -> t.Optional[exp.Expression]: + this = self._parse_field() + if ( + not this + and self._match(TokenType.VALUES, advance=False) + and self.VALUES_FOLLOWED_BY_PAREN + and (not self._next or self._next.token_type != TokenType.L_PAREN) + ): + this = self._parse_id_var() + + if isinstance(this, exp.Identifier): + # We bubble up comments from the Identifier to the Column + this = self.expression(exp.Column, comments=this.pop_comments(), this=this) + + return this + + def _parse_colon_as_variant_extract( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + casts = [] + json_path = [] + escape = None + + while self._match(TokenType.COLON): + start_index = self._index + + # Snowflake allows reserved keywords as json keys but advance_any() excludes TokenType.SELECT from any_tokens=True + path = self._parse_column_ops( + self._parse_field(any_token=True, tokens=(TokenType.SELECT,)) + ) + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the JSON path + while isinstance(path, exp.Cast): + casts.append(path.to) + path = path.this + + if casts: + dcolon_offset = next( + i + for i, t in enumerate(self._tokens[start_index:]) + if t.token_type == TokenType.DCOLON + ) + end_token = self._tokens[start_index + dcolon_offset - 1] + else: + end_token = self._prev + + if path: + # Escape single quotes from Snowflake's colon extraction (e.g. 
col:"a'b") as + # it'll roundtrip to a string literal in GET_PATH + if isinstance(path, exp.Identifier) and path.quoted: + escape = True + + json_path.append(self._find_sql(self._tokens[start_index], end_token)) + + # The VARIANT extract in Snowflake/Databricks is parsed as a JSONExtract; Snowflake uses the json_path in GET_PATH() while + # Databricks transforms it back to the colon/dot notation + if json_path: + json_path_expr = self.dialect.to_json_path( + exp.Literal.string(".".join(json_path)) + ) + + if json_path_expr: + json_path_expr.set("escape", escape) + + this = self.expression( + exp.JSONExtract, + this=this, + expression=json_path_expr, + variant_extract=True, + requires_json=self.JSON_EXTRACT_REQUIRES_JSON_EXPRESSION, + ) + + while casts: + this = self.expression(exp.Cast, this=this, to=casts.pop()) + + return this + + def _parse_dcolon(self) -> t.Optional[exp.Expression]: + return self._parse_types() + + def _parse_column_ops( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + this = self._parse_bracket(this) + + while self._match_set(self.COLUMN_OPERATORS): + op_token = self._prev.token_type + op = self.COLUMN_OPERATORS.get(op_token) + + if op_token in self.CAST_COLUMN_OPERATORS: + field = self._parse_dcolon() + if not field: + self.raise_error("Expected type") + elif op and self._curr: + field = self._parse_column_reference() or self._parse_bitwise() + if isinstance(field, exp.Column) and self._match( + TokenType.DOT, advance=False + ): + field = self._parse_column_ops(field) + else: + field = self._parse_field(any_token=True, anonymous_func=True) + + # Function calls can be qualified, e.g., x.y.FOO() + # This converts the final AST to a series of Dots leading to the function call + # https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules + if isinstance(field, (exp.Func, exp.Window)) and this: + this = this.transform( + lambda n: n.to_dot(include_dots=False) + if isinstance(n, exp.Column) + else n + ) + + if op: + this = op(self, this, field) + elif isinstance(this, exp.Column) and not this.args.get("catalog"): + this = self.expression( + exp.Column, + comments=this.comments, + this=field, + table=this.this, + db=this.args.get("table"), + catalog=this.args.get("db"), + ) + elif isinstance(field, exp.Window): + # Move the exp.Dot's to the window's function + window_func = self.expression(exp.Dot, this=this, expression=field.this) + field.set("this", window_func) + this = field + else: + this = self.expression(exp.Dot, this=this, expression=field) + + if field and field.comments: + t.cast(exp.Expression, this).add_comments(field.pop_comments()) + + this = self._parse_bracket(this) + + return ( + self._parse_colon_as_variant_extract(this) + if self.COLON_IS_VARIANT_EXTRACT + else this + ) + + def _parse_paren(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.L_PAREN): + return None + + comments = self._prev_comments + query = self._parse_select() + + if query: + expressions = [query] + else: + expressions = self._parse_expressions() + + this = seq_get(expressions, 0) + + if not this and self._match(TokenType.R_PAREN, advance=False): + this = self.expression(exp.Tuple) + elif isinstance(this, exp.UNWRAPPED_QUERIES): + this = self._parse_subquery(this=this, parse_alias=False) + elif isinstance(this, (exp.Subquery, exp.Values)): + this = self._parse_subquery( + this=self._parse_query_modifiers(self._parse_set_operations(this)), + parse_alias=False, + ) + elif len(expressions) > 1 or 
self._prev.token_type == TokenType.COMMA:
+            this = self.expression(exp.Tuple, expressions=expressions)
+        else:
+            this = self.expression(exp.Paren, this=this)
+
+        if this:
+            this.add_comments(comments)
+
+        self._match_r_paren(expression=this)
+
+        if isinstance(this, exp.Paren) and isinstance(this.this, exp.AggFunc):
+            return self._parse_window(this)
+
+        return this
+
+    def _parse_primary(self) -> t.Optional[exp.Expression]:
+        if self._match_set(self.PRIMARY_PARSERS):
+            token_type = self._prev.token_type
+            primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
+
+            if token_type == TokenType.STRING:
+                expressions = [primary]
+                while self._match(TokenType.STRING):
+                    expressions.append(exp.Literal.string(self._prev.text))
+
+                if len(expressions) > 1:
+                    return self.expression(
+                        exp.Concat,
+                        expressions=expressions,
+                        coalesce=self.dialect.CONCAT_COALESCE,
+                    )
+
+            return primary
+
+        if self._match_pair(TokenType.DOT, TokenType.NUMBER):
+            return exp.Literal.number(f"0.{self._prev.text}")
+
+        return self._parse_paren()
+
+    def _parse_field(
+        self,
+        any_token: bool = False,
+        tokens: t.Optional[t.Collection[TokenType]] = None,
+        anonymous_func: bool = False,
+    ) -> t.Optional[exp.Expression]:
+        if anonymous_func:
+            field = (
+                self._parse_function(anonymous=anonymous_func, any_token=any_token)
+                or self._parse_primary()
+            )
+        else:
+            field = self._parse_primary() or self._parse_function(
+                anonymous=anonymous_func, any_token=any_token
+            )
+        return field or self._parse_id_var(any_token=any_token, tokens=tokens)
+
+    def _parse_function(
+        self,
+        functions: t.Optional[t.Dict[str, t.Callable]] = None,
+        anonymous: bool = False,
+        optional_parens: bool = True,
+        any_token: bool = False,
+    ) -> t.Optional[exp.Expression]:
+        # This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this)
+        # See: https://community.snowflake.com/s/article/SQL-Escape-Sequences
+        fn_syntax = False
+        if (
+            self._match(TokenType.L_BRACE, advance=False)
+            and self._next
+            and self._next.text.upper() == "FN"
+        ):
+            self._advance(2)
+            fn_syntax = True
+
+        func = self._parse_function_call(
+            functions=functions,
+            anonymous=anonymous,
+            optional_parens=optional_parens,
+            any_token=any_token,
+        )
+
+        if fn_syntax:
+            self._match(TokenType.R_BRACE)
+
+        return func
+
+    def _parse_function_args(self, alias: bool = False) -> t.List[exp.Expression]:
+        return self._parse_csv(lambda: self._parse_lambda(alias=alias))
+
+    def _parse_function_call(
+        self,
+        functions: t.Optional[t.Dict[str, t.Callable]] = None,
+        anonymous: bool = False,
+        optional_parens: bool = True,
+        any_token: bool = False,
+    ) -> t.Optional[exp.Expression]:
+        if not self._curr:
+            return None
+
+        comments = self._curr.comments
+        prev = self._prev
+        token = self._curr
+        token_type = self._curr.token_type
+        this = self._curr.text
+        upper = this.upper()
+
+        parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
+        if (
+            optional_parens
+            and parser
+            and token_type not in self.INVALID_FUNC_NAME_TOKENS
+        ):
+            self._advance()
+            return self._parse_window(parser(self))
+
+        if not self._next or self._next.token_type != TokenType.L_PAREN:
+            if optional_parens and token_type in self.NO_PAREN_FUNCTIONS:
+                self._advance()
+                return self.expression(self.NO_PAREN_FUNCTIONS[token_type])
+
+            return None
+
+        if any_token:
+            if token_type in self.RESERVED_TOKENS:
+                return None
+        elif token_type not in self.FUNC_TOKENS:
+            return None
+
+        self._advance(2)
+
+        parser = self.FUNCTION_PARSERS.get(upper)
+        if parser and not anonymous:
+            this = parser(self)
+        else:
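+            # NOTE (editorial, not upstream text): names that reach this branch are
+            # either subquery predicates, e.g. x = ANY (SELECT ...), builders looked
+            # up in self.FUNCTIONS, or unknown functions parsed as exp.Anonymous.
+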
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) + + if subquery_predicate: + expr = None + if self._curr.token_type in (TokenType.SELECT, TokenType.WITH): + expr = self._parse_select() + self._match_r_paren() + elif prev and prev.token_type in (TokenType.LIKE, TokenType.ILIKE): + # Backtrack one token since we've consumed the L_PAREN here. Instead, we'd like + # to parse "LIKE [ANY | ALL] (...)" as a whole into an exp.Tuple or exp.Paren + self._advance(-1) + expr = self._parse_bitwise() + + if expr: + return self.expression( + subquery_predicate, comments=comments, this=expr + ) + + if functions is None: + functions = self.FUNCTIONS + + function = functions.get(upper) + known_function = function and not anonymous + + alias = not known_function or upper in self.FUNCTIONS_WITH_ALIASED_ARGS + args = self._parse_function_args(alias) + + post_func_comments = self._curr and self._curr.comments + if known_function and post_func_comments: + # If the user-inputted comment "/* sqlglot.anonymous */" is following the function + # call we'll construct it as exp.Anonymous, even if it's "known" + if any( + comment.lstrip().startswith(exp.SQLGLOT_ANONYMOUS) + for comment in post_func_comments + ): + known_function = False + + if alias and known_function: + args = self._kv_to_prop_eq(args) + + if known_function: + func_builder = t.cast(t.Callable, function) + + if "dialect" in func_builder.__code__.co_varnames: + func = func_builder(args, dialect=self.dialect) + else: + func = func_builder(args) + + func = self.validate_expression(func, args) + if self.dialect.PRESERVE_ORIGINAL_NAMES: + func.meta["name"] = this + + this = func + else: + if token_type == TokenType.IDENTIFIER: + this = exp.Identifier(this=this, quoted=True).update_positions( + token + ) + + this = self.expression(exp.Anonymous, this=this, expressions=args) + + this = this.update_positions(token) + + if isinstance(this, exp.Expression): + this.add_comments(comments) + + self._match_r_paren(this) + return self._parse_window(this) + + def _to_prop_eq(self, expression: exp.Expression, index: int) -> exp.Expression: + return expression + + def _kv_to_prop_eq( + self, expressions: t.List[exp.Expression], parse_map: bool = False + ) -> t.List[exp.Expression]: + transformed = [] + + for index, e in enumerate(expressions): + if isinstance(e, self.KEY_VALUE_DEFINITIONS): + if isinstance(e, exp.Alias): + e = self.expression( + exp.PropertyEQ, this=e.args.get("alias"), expression=e.this + ) + + if not isinstance(e, exp.PropertyEQ): + e = self.expression( + exp.PropertyEQ, + this=e.this if parse_map else exp.to_identifier(e.this.name), + expression=e.expression, + ) + + if isinstance(e.this, exp.Column): + e.this.replace(e.this.this) + else: + e = self._to_prop_eq(e, index) + + transformed.append(e) + + return transformed + + def _parse_user_defined_function_expression(self) -> t.Optional[exp.Expression]: + return self._parse_statement() + + def _parse_function_parameter(self) -> t.Optional[exp.Expression]: + return self._parse_column_def(this=self._parse_id_var(), computed_column=False) + + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: + this = self._parse_table_parts(schema=True) + + if not self._match(TokenType.L_PAREN): + return this + + expressions = self._parse_csv(self._parse_function_parameter) + self._match_r_paren() + return self.expression( + exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True + ) + + def _parse_introducer(self, token: Token) 
-> exp.Introducer | exp.Identifier:
+        literal = self._parse_primary()
+        if literal:
+            return self.expression(exp.Introducer, token=token, expression=literal)
+
+        return self._identifier_expression(token)
+
+    def _parse_session_parameter(self) -> exp.SessionParameter:
+        kind = None
+        this = self._parse_id_var() or self._parse_primary()
+
+        if this and self._match(TokenType.DOT):
+            kind = this.name
+            this = self._parse_var() or self._parse_primary()
+
+        return self.expression(exp.SessionParameter, this=this, kind=kind)
+
+    def _parse_lambda_arg(self) -> t.Optional[exp.Expression]:
+        return self._parse_id_var()
+
+    def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
+        index = self._index
+
+        if self._match(TokenType.L_PAREN):
+            expressions = t.cast(
+                t.List[t.Optional[exp.Expression]],
+                self._parse_csv(self._parse_lambda_arg),
+            )
+
+            if not self._match(TokenType.R_PAREN):
+                self._retreat(index)
+        else:
+            expressions = [self._parse_lambda_arg()]
+
+        if self._match_set(self.LAMBDAS):
+            return self.LAMBDAS[self._prev.token_type](self, expressions)
+
+        self._retreat(index)
+
+        this: t.Optional[exp.Expression]
+
+        if self._match(TokenType.DISTINCT):
+            this = self.expression(
+                exp.Distinct, expressions=self._parse_csv(self._parse_disjunction)
+            )
+        else:
+            this = self._parse_select_or_expression(alias=alias)
+
+        return self._parse_limit(
+            self._parse_order(
+                self._parse_having_max(self._parse_respect_or_ignore_nulls(this))
+            )
+        )
+
+    def _parse_schema(
+        self, this: t.Optional[exp.Expression] = None
+    ) -> t.Optional[exp.Expression]:
+        index = self._index
+        if not self._match(TokenType.L_PAREN):
+            return this
+
+        # Disambiguate between schema and subquery/CTE, e.g. in INSERT INTO table (<expr>),
+        # expr can be of both types
+        if self._match_set(self.SELECT_START_TOKENS):
+            self._retreat(index)
+            return this
+        args = self._parse_csv(
+            lambda: self._parse_constraint() or self._parse_field_def()
+        )
+        self._match_r_paren()
+        return self.expression(exp.Schema, this=this, expressions=args)
+
+    def _parse_field_def(self) -> t.Optional[exp.Expression]:
+        return self._parse_column_def(self._parse_field(any_token=True))
+
+    def _parse_column_def(
+        self, this: t.Optional[exp.Expression], computed_column: bool = True
+    ) -> t.Optional[exp.Expression]:
+        # column defs are not really columns, they're identifiers
+        if isinstance(this, exp.Column):
+            this = this.this
+
+        if not computed_column:
+            self._match(TokenType.ALIAS)
+
+        kind = self._parse_types(schema=True)
+
+        if self._match_text_seq("FOR", "ORDINALITY"):
+            return self.expression(exp.ColumnDef, this=this, ordinality=True)
+
+        constraints: t.List[exp.Expression] = []
+
+        if (not kind and self._match(TokenType.ALIAS)) or self._match_texts(
+            ("ALIAS", "MATERIALIZED")
+        ):
+            persisted = self._prev.text.upper() == "MATERIALIZED"
+            constraint_kind = exp.ComputedColumnConstraint(
+                this=self._parse_disjunction(),
+                persisted=persisted or self._match_text_seq("PERSISTED"),
+                data_type=exp.Var(this="AUTO")
+                if self._match_text_seq("AUTO")
+                else self._parse_types(),
+                not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
+            )
+            constraints.append(
+                self.expression(exp.ColumnConstraint, kind=constraint_kind)
+            )
+        elif (
+            kind
+            and self._match(TokenType.ALIAS, advance=False)
+            and (
+                not self.WRAPPED_TRANSFORM_COLUMN_CONSTRAINT
+                or (self._next and self._next.token_type == TokenType.L_PAREN)
+            )
+        ):
+            self._advance()
+            constraints.append(
+                self.expression(
+                    exp.ColumnConstraint,
+                    kind=exp.ComputedColumnConstraint(
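+                        # NOTE (editorial, not upstream text): this branch covers
+                        # computed columns declared with AS, e.g. col INT AS (a + b)
+                        # STORED; `persisted` ends up True for STORED and False for
+                        # VIRTUAL generated columns.
+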
this=self._parse_disjunction(), + persisted=self._match_texts(("STORED", "VIRTUAL")) + and self._prev.text.upper() == "STORED", + ), + ) + ) + + while True: + constraint = self._parse_column_constraint() + if not constraint: + break + constraints.append(constraint) + + if not kind and not constraints: + return this + + return self.expression( + exp.ColumnDef, this=this, kind=kind, constraints=constraints + ) + + def _parse_auto_increment( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint: + start = None + increment = None + order = None + + if self._match(TokenType.L_PAREN, advance=False): + args = self._parse_wrapped_csv(self._parse_bitwise) + start = seq_get(args, 0) + increment = seq_get(args, 1) + elif self._match_text_seq("START"): + start = self._parse_bitwise() + self._match_text_seq("INCREMENT") + increment = self._parse_bitwise() + if self._match_text_seq("ORDER"): + order = True + elif self._match_text_seq("NOORDER"): + order = False + + if start and increment: + return exp.GeneratedAsIdentityColumnConstraint( + start=start, increment=increment, this=False, order=order + ) + + return exp.AutoIncrementColumnConstraint() + + def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]: + if not self._match_text_seq("REFRESH"): + self._retreat(self._index - 1) + return None + return self.expression( + exp.AutoRefreshProperty, this=self._parse_var(upper=True) + ) + + def _parse_compress(self) -> exp.CompressColumnConstraint: + if self._match(TokenType.L_PAREN, advance=False): + return self.expression( + exp.CompressColumnConstraint, + this=self._parse_wrapped_csv(self._parse_bitwise), + ) + + return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) + + def _parse_generated_as_identity( + self, + ) -> ( + exp.GeneratedAsIdentityColumnConstraint + | exp.ComputedColumnConstraint + | exp.GeneratedAsRowColumnConstraint + ): + if self._match_text_seq("BY", "DEFAULT"): + on_null = self._match_pair(TokenType.ON, TokenType.NULL) + this = self.expression( + exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null + ) + else: + self._match_text_seq("ALWAYS") + this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) + + self._match(TokenType.ALIAS) + + if self._match_text_seq("ROW"): + start = self._match_text_seq("START") + if not start: + self._match(TokenType.END) + hidden = self._match_text_seq("HIDDEN") + return self.expression( + exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden + ) + + identity = self._match_text_seq("IDENTITY") + + if self._match(TokenType.L_PAREN): + if self._match(TokenType.START_WITH): + this.set("start", self._parse_bitwise()) + if self._match_text_seq("INCREMENT", "BY"): + this.set("increment", self._parse_bitwise()) + if self._match_text_seq("MINVALUE"): + this.set("minvalue", self._parse_bitwise()) + if self._match_text_seq("MAXVALUE"): + this.set("maxvalue", self._parse_bitwise()) + + if self._match_text_seq("CYCLE"): + this.set("cycle", True) + elif self._match_text_seq("NO", "CYCLE"): + this.set("cycle", False) + + if not identity: + this.set("expression", self._parse_range()) + elif not this.args.get("start") and self._match( + TokenType.NUMBER, advance=False + ): + args = self._parse_csv(self._parse_bitwise) + this.set("start", seq_get(args, 0)) + this.set("increment", seq_get(args, 1)) + + self._match_r_paren() + + return this + + def _parse_inline(self) -> exp.InlineLengthColumnConstraint: + self._match_text_seq("LENGTH") + return 
self.expression( + exp.InlineLengthColumnConstraint, this=self._parse_bitwise() + ) + + def _parse_not_constraint(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("NULL"): + return self.expression(exp.NotNullColumnConstraint) + if self._match_text_seq("CASESPECIFIC"): + return self.expression(exp.CaseSpecificColumnConstraint, not_=True) + if self._match_text_seq("FOR", "REPLICATION"): + return self.expression(exp.NotForReplicationColumnConstraint) + + # Unconsume the `NOT` token + self._retreat(self._index - 1) + return None + + def _parse_column_constraint(self) -> t.Optional[exp.Expression]: + this = self._match(TokenType.CONSTRAINT) and self._parse_id_var() + + procedure_option_follows = ( + self._match(TokenType.WITH, advance=False) + and self._next + and self._next.text.upper() in self.PROCEDURE_OPTIONS + ) + + if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS): + return self.expression( + exp.ColumnConstraint, + this=this, + kind=self.CONSTRAINT_PARSERS[self._prev.text.upper()](self), + ) + + return this + + def _parse_constraint(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.CONSTRAINT): + return self._parse_unnamed_constraint( + constraints=self.SCHEMA_UNNAMED_CONSTRAINTS + ) + + return self.expression( + exp.Constraint, + this=self._parse_id_var(), + expressions=self._parse_unnamed_constraints(), + ) + + def _parse_unnamed_constraints(self) -> t.List[exp.Expression]: + constraints = [] + while True: + constraint = self._parse_unnamed_constraint() or self._parse_function() + if not constraint: + break + constraints.append(constraint) + + return constraints + + def _parse_unnamed_constraint( + self, constraints: t.Optional[t.Collection[str]] = None + ) -> t.Optional[exp.Expression]: + if self._match(TokenType.IDENTIFIER, advance=False) or not self._match_texts( + constraints or self.CONSTRAINT_PARSERS + ): + return None + + constraint = self._prev.text.upper() + if constraint not in self.CONSTRAINT_PARSERS: + self.raise_error(f"No parser found for schema constraint {constraint}.") + + return self.CONSTRAINT_PARSERS[constraint](self) + + def _parse_unique_key(self) -> t.Optional[exp.Expression]: + return self._parse_id_var(any_token=False) + + def _parse_unique(self) -> exp.UniqueColumnConstraint: + self._match_texts(("KEY", "INDEX")) + return self.expression( + exp.UniqueColumnConstraint, + nulls=self._match_text_seq("NULLS", "NOT", "DISTINCT"), + this=self._parse_schema(self._parse_unique_key()), + index_type=self._match(TokenType.USING) + and self._advance_any() + and self._prev.text, + on_conflict=self._parse_on_conflict(), + options=self._parse_key_constraint_options(), + ) + + def _parse_key_constraint_options(self) -> t.List[str]: + options = [] + while True: + if not self._curr: + break + + if self._match(TokenType.ON): + action = None + on = self._advance_any() and self._prev.text + + if self._match_text_seq("NO", "ACTION"): + action = "NO ACTION" + elif self._match_text_seq("CASCADE"): + action = "CASCADE" + elif self._match_text_seq("RESTRICT"): + action = "RESTRICT" + elif self._match_pair(TokenType.SET, TokenType.NULL): + action = "SET NULL" + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + action = "SET DEFAULT" + else: + self.raise_error("Invalid key constraint") + + options.append(f"ON {on} {action}") + else: + var = self._parse_var_from_options( + self.KEY_CONSTRAINT_OPTIONS, raise_unmatched=False + ) + if not var: + break + options.append(var.name) + + return options + + def 
_parse_references(self, match: bool = True) -> t.Optional[exp.Reference]: + if match and not self._match(TokenType.REFERENCES): + return None + + expressions = None + this = self._parse_table(schema=True) + options = self._parse_key_constraint_options() + return self.expression( + exp.Reference, this=this, expressions=expressions, options=options + ) + + def _parse_foreign_key(self) -> exp.ForeignKey: + expressions = ( + self._parse_wrapped_id_vars() + if not self._match(TokenType.REFERENCES, advance=False) + else None + ) + reference = self._parse_references() + on_options = {} + + while self._match(TokenType.ON): + if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): + self.raise_error("Expected DELETE or UPDATE") + + kind = self._prev.text.lower() + + if self._match_text_seq("NO", "ACTION"): + action = "NO ACTION" + elif self._match(TokenType.SET): + self._match_set((TokenType.NULL, TokenType.DEFAULT)) + action = "SET " + self._prev.text.upper() + else: + self._advance() + action = self._prev.text.upper() + + on_options[kind] = action + + return self.expression( + exp.ForeignKey, + expressions=expressions, + reference=reference, + options=self._parse_key_constraint_options(), + **on_options, # type: ignore + ) + + def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: + return self._parse_field() + + def _parse_period_for_system_time( + self, + ) -> t.Optional[exp.PeriodForSystemTimeConstraint]: + if not self._match(TokenType.TIMESTAMP_SNAPSHOT): + self._retreat(self._index - 1) + return None + + id_vars = self._parse_wrapped_id_vars() + return self.expression( + exp.PeriodForSystemTimeConstraint, + this=seq_get(id_vars, 0), + expression=seq_get(id_vars, 1), + ) + + def _parse_primary_key( + self, wrapped_optional: bool = False, in_props: bool = False + ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: + desc = ( + self._match_set((TokenType.ASC, TokenType.DESC)) + and self._prev.token_type == TokenType.DESC + ) + + this = None + if ( + self._curr.text.upper() not in self.CONSTRAINT_PARSERS + and self._next + and self._next.token_type == TokenType.L_PAREN + ): + this = self._parse_id_var() + + if not in_props and not self._match(TokenType.L_PAREN, advance=False): + return self.expression( + exp.PrimaryKeyColumnConstraint, + desc=desc, + options=self._parse_key_constraint_options(), + ) + + expressions = self._parse_wrapped_csv( + self._parse_primary_key_part, optional=wrapped_optional + ) + + return self.expression( + exp.PrimaryKey, + this=this, + expressions=expressions, + include=self._parse_index_params(), + options=self._parse_key_constraint_options(), + ) + + def _parse_bracket_key_value( + self, is_map: bool = False + ) -> t.Optional[exp.Expression]: + return self._parse_slice( + self._parse_alias(self._parse_disjunction(), explicit=True) + ) + + def _parse_odbc_datetime_literal(self) -> exp.Expression: + """ + Parses a datetime column in ODBC format. We parse the column into the corresponding + types, for example `{d'yyyy-mm-dd'}` will be parsed as a `Date` column, exactly the + same as we did for `DATE('yyyy-mm-dd')`. 
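+        The `{t '...'}` and `{ts '...'}` ODBC escapes for times and timestamps are
+        resolved the same way, through ODBC_DATETIME_LITERALS (editorial note, not
+        upstream text).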
+ + Reference: + https://learn.microsoft.com/en-us/sql/odbc/reference/develop-app/date-time-and-timestamp-literals + """ + self._match(TokenType.VAR) + exp_class = self.ODBC_DATETIME_LITERALS[self._prev.text.lower()] + expression = self.expression(exp_class=exp_class, this=self._parse_string()) + if not self._match(TokenType.R_BRACE): + self.raise_error("Expected }") + return expression + + def _parse_bracket( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): + return this + + if self.MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS: + map_token = seq_get(self._tokens, self._index - 2) + parse_map = map_token is not None and map_token.text.upper() == "MAP" + else: + parse_map = False + + bracket_kind = self._prev.token_type + if ( + bracket_kind == TokenType.L_BRACE + and self._curr + and self._curr.token_type == TokenType.VAR + and self._curr.text.lower() in self.ODBC_DATETIME_LITERALS + ): + return self._parse_odbc_datetime_literal() + + expressions = self._parse_csv( + lambda: self._parse_bracket_key_value( + is_map=bracket_kind == TokenType.L_BRACE + ) + ) + + if bracket_kind == TokenType.L_BRACKET and not self._match(TokenType.R_BRACKET): + self.raise_error("Expected ]") + elif bracket_kind == TokenType.L_BRACE and not self._match(TokenType.R_BRACE): + self.raise_error("Expected }") + + # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs + if bracket_kind == TokenType.L_BRACE: + this = self.expression( + exp.Struct, + expressions=self._kv_to_prop_eq( + expressions=expressions, parse_map=parse_map + ), + ) + elif not this: + this = build_array_constructor( + exp.Array, + args=expressions, + bracket_kind=bracket_kind, + dialect=self.dialect, + ) + else: + constructor_type = self.ARRAY_CONSTRUCTORS.get(this.name.upper()) + if constructor_type: + return build_array_constructor( + constructor_type, + args=expressions, + bracket_kind=bracket_kind, + dialect=self.dialect, + ) + + expressions = apply_index_offset( + this, expressions, -self.dialect.INDEX_OFFSET, dialect=self.dialect + ) + this = self.expression( + exp.Bracket, + this=this, + expressions=expressions, + comments=this.pop_comments(), + ) + + self._add_comments(this) + return self._parse_bracket(this) + + def _parse_slice( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if not self._match(TokenType.COLON): + return this + + if self._match_pair(TokenType.DASH, TokenType.COLON, advance=False): + self._advance() + end: t.Optional[exp.Expression] = -exp.Literal.number("1") + else: + end = self._parse_unary() + step = self._parse_unary() if self._match(TokenType.COLON) else None + return self.expression(exp.Slice, this=this, expression=end, step=step) + + def _parse_case(self) -> t.Optional[exp.Expression]: + if self._match(TokenType.DOT, advance=False): + # Avoid raising on valid expressions like case.*, supported by, e.g., spark & snowflake + self._retreat(self._index - 1) + return None + + ifs = [] + default = None + + comments = self._prev_comments + expression = self._parse_disjunction() + + while self._match(TokenType.WHEN): + this = self._parse_disjunction() + self._match(TokenType.THEN) + then = self._parse_disjunction() + ifs.append(self.expression(exp.If, this=this, true=then)) + + if self._match(TokenType.ELSE): + default = self._parse_disjunction() + + if not self._match(TokenType.END): + if ( + isinstance(default, exp.Interval) + and default.this.sql().upper() == "END" + ): + default = 
exp.column("interval") + else: + self.raise_error("Expected END after CASE", self._prev) + + return self.expression( + exp.Case, comments=comments, this=expression, ifs=ifs, default=default + ) + + def _parse_if(self) -> t.Optional[exp.Expression]: + if self._match(TokenType.L_PAREN): + args = self._parse_csv( + lambda: self._parse_alias(self._parse_assignment(), explicit=True) + ) + this = self.validate_expression(exp.If.from_arg_list(args), args) + self._match_r_paren() + else: + index = self._index - 1 + + if self.NO_PAREN_IF_COMMANDS and index == 0: + return self._parse_as_command(self._prev) + + condition = self._parse_disjunction() + + if not condition: + self._retreat(index) + return None + + self._match(TokenType.THEN) + true = self._parse_disjunction() + false = self._parse_disjunction() if self._match(TokenType.ELSE) else None + self._match(TokenType.END) + this = self.expression(exp.If, this=condition, true=true, false=false) + + return this + + def _parse_next_value_for(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("VALUE", "FOR"): + self._retreat(self._index - 1) + return None + + return self.expression( + exp.NextValueFor, + this=self._parse_column(), + order=self._match(TokenType.OVER) + and self._parse_wrapped(self._parse_order), + ) + + def _parse_extract(self) -> exp.Extract | exp.Anonymous: + this = self._parse_function() or self._parse_var_or_string(upper=True) + + if self._match(TokenType.FROM): + return self.expression( + exp.Extract, this=this, expression=self._parse_bitwise() + ) + + if not self._match(TokenType.COMMA): + self.raise_error("Expected FROM or comma after EXTRACT", self._prev) + + return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) + + def _parse_gap_fill(self) -> exp.GapFill: + self._match(TokenType.TABLE) + this = self._parse_table() + + self._match(TokenType.COMMA) + args = [this, *self._parse_csv(self._parse_lambda)] + + gap_fill = exp.GapFill.from_arg_list(args) + return self.validate_expression(gap_fill, args) + + def _parse_cast( + self, strict: bool, safe: t.Optional[bool] = None + ) -> exp.Expression: + this = self._parse_disjunction() + + if not self._match(TokenType.ALIAS): + if self._match(TokenType.COMMA): + return self.expression( + exp.CastToStrType, this=this, to=self._parse_string() + ) + + self.raise_error("Expected AS after CAST") + + fmt = None + to = self._parse_types() + + default = self._match(TokenType.DEFAULT) + if default: + default = self._parse_bitwise() + self._match_text_seq("ON", "CONVERSION", "ERROR") + + if self._match_set((TokenType.FORMAT, TokenType.COMMA)): + fmt_string = self._parse_string() + fmt = self._parse_at_time_zone(fmt_string) + + if not to: + to = exp.DataType.build(exp.DataType.Type.UNKNOWN) + if to.this in exp.DataType.TEMPORAL_TYPES: + this = self.expression( + exp.StrToDate + if to.this == exp.DataType.Type.DATE + else exp.StrToTime, + this=this, + format=exp.Literal.string( + format_time( + fmt_string.this if fmt_string else "", + self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING, + self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE, + ) + ), + safe=safe, + ) + + if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime): + this.set("zone", fmt.args["zone"]) + return this + elif not to: + self.raise_error("Expected TYPE after CAST") + elif isinstance(to, exp.Identifier): + to = exp.DataType.build(to.name, dialect=self.dialect, udt=True) + elif to.this == exp.DataType.Type.CHAR: + if self._match(TokenType.CHARACTER_SET): + to = 
self.expression(exp.CharacterSet, this=self._parse_var_or_string()) + + return self.build_cast( + strict=strict, + this=this, + to=to, + format=fmt, + safe=safe, + action=self._parse_var_from_options( + self.CAST_ACTIONS, raise_unmatched=False + ), + default=default, + ) + + def _parse_string_agg(self) -> exp.GroupConcat: + if self._match(TokenType.DISTINCT): + args: t.List[t.Optional[exp.Expression]] = [ + self.expression(exp.Distinct, expressions=[self._parse_disjunction()]) + ] + if self._match(TokenType.COMMA): + args.extend(self._parse_csv(self._parse_disjunction)) + else: + args = self._parse_csv(self._parse_disjunction) # type: ignore + + if self._match_text_seq("ON", "OVERFLOW"): + # trino: LISTAGG(expression [, separator] [ON OVERFLOW overflow_behavior]) + if self._match_text_seq("ERROR"): + on_overflow: t.Optional[exp.Expression] = exp.var("ERROR") + else: + self._match_text_seq("TRUNCATE") + on_overflow = self.expression( + exp.OverflowTruncateBehavior, + this=self._parse_string(), + with_count=( + self._match_text_seq("WITH", "COUNT") + or not self._match_text_seq("WITHOUT", "COUNT") + ), + ) + else: + on_overflow = None + + index = self._index + if not self._match(TokenType.R_PAREN) and args: + # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) + # bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n]) + # The order is parsed through `this` as a canonicalization for WITHIN GROUPs + args[0] = self._parse_limit(this=self._parse_order(this=args[0])) + return self.expression( + exp.GroupConcat, this=args[0], separator=seq_get(args, 1) + ) + + # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY [ASC | DESC]). + # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that + # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. 
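+        # NOTE (editorial, not upstream text): e.g. STRING_AGG(x, ',') WITHIN GROUP
+        # (ORDER BY y) takes the branch below, with the ORDER BY folded into `this`
+        # by _parse_order.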
+ if not self._match_text_seq("WITHIN", "GROUP"): + self._retreat(index) + return self.validate_expression(exp.GroupConcat.from_arg_list(args), args) + + # The corresponding match_r_paren will be called in parse_function (caller) + self._match_l_paren() + + return self.expression( + exp.GroupConcat, + this=self._parse_order(this=seq_get(args, 0)), + separator=seq_get(args, 1), + on_overflow=on_overflow, + ) + + def _parse_convert( + self, strict: bool, safe: t.Optional[bool] = None + ) -> t.Optional[exp.Expression]: + this = self._parse_bitwise() + + if self._match(TokenType.USING): + to: t.Optional[exp.Expression] = self.expression( + exp.CharacterSet, this=self._parse_var() + ) + elif self._match(TokenType.COMMA): + to = self._parse_types() + else: + to = None + + return self.build_cast(strict=strict, this=this, to=to, safe=safe) + + def _parse_xml_table(self) -> exp.XMLTable: + namespaces = None + passing = None + columns = None + + if self._match_text_seq("XMLNAMESPACES", "("): + namespaces = self._parse_xml_namespace() + self._match_text_seq(")", ",") + + this = self._parse_string() + + if self._match_text_seq("PASSING"): + # The BY VALUE keywords are optional and are provided for semantic clarity + self._match_text_seq("BY", "VALUE") + passing = self._parse_csv(self._parse_column) + + by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") + + if self._match_text_seq("COLUMNS"): + columns = self._parse_csv(self._parse_field_def) + + return self.expression( + exp.XMLTable, + this=this, + namespaces=namespaces, + passing=passing, + columns=columns, + by_ref=by_ref, + ) + + def _parse_xml_namespace(self) -> t.List[exp.XMLNamespace]: + namespaces = [] + + while True: + if self._match(TokenType.DEFAULT): + uri = self._parse_string() + else: + uri = self._parse_alias(self._parse_string()) + namespaces.append(self.expression(exp.XMLNamespace, this=uri)) + if not self._match(TokenType.COMMA): + break + + return namespaces + + def _parse_decode(self) -> t.Optional[exp.Decode | exp.DecodeCase]: + args = self._parse_csv(self._parse_disjunction) + + if len(args) < 3: + return self.expression( + exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1) + ) + + return self.expression(exp.DecodeCase, expressions=args) + + def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: + self._match_text_seq("KEY") + key = self._parse_column() + self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS) + self._match_text_seq("VALUE") + value = self._parse_bitwise() + + if not key and not value: + return None + return self.expression(exp.JSONKeyValue, this=key, expression=value) + + def _parse_format_json( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if not this or not self._match_text_seq("FORMAT", "JSON"): + return this + + return self.expression(exp.FormatJson, this=this) + + def _parse_on_condition(self) -> t.Optional[exp.OnCondition]: + # MySQL uses "X ON EMPTY Y ON ERROR" (e.g. JSON_VALUE) while Oracle uses the opposite (e.g. 
JSON_EXISTS) + if self.dialect.ON_CONDITION_EMPTY_BEFORE_ERROR: + empty = self._parse_on_handling("EMPTY", *self.ON_CONDITION_TOKENS) + error = self._parse_on_handling("ERROR", *self.ON_CONDITION_TOKENS) + else: + error = self._parse_on_handling("ERROR", *self.ON_CONDITION_TOKENS) + empty = self._parse_on_handling("EMPTY", *self.ON_CONDITION_TOKENS) + + null = self._parse_on_handling("NULL", *self.ON_CONDITION_TOKENS) + + if not empty and not error and not null: + return None + + return self.expression( + exp.OnCondition, + empty=empty, + error=error, + null=null, + ) + + def _parse_on_handling( + self, on: str, *values: str + ) -> t.Optional[str] | t.Optional[exp.Expression]: + # Parses the "X ON Y" or "DEFAULT ON Y syntax, e.g. NULL ON NULL (Oracle, T-SQL, MySQL) + for value in values: + if self._match_text_seq(value, "ON", on): + return f"{value} ON {on}" + + index = self._index + if self._match(TokenType.DEFAULT): + default_value = self._parse_bitwise() + if self._match_text_seq("ON", on): + return default_value + + self._retreat(index) + + return None + + @t.overload + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: + ... + + def _parse_json_object(self, agg=False): + star = self._parse_star() + expressions = ( + [star] + if star + else self._parse_csv( + lambda: self._parse_format_json(self._parse_json_key_value()) + ) + ) + null_handling = self._parse_on_handling("NULL", "NULL", "ABSENT") + + unique_keys = None + if self._match_text_seq("WITH", "UNIQUE"): + unique_keys = True + elif self._match_text_seq("WITHOUT", "UNIQUE"): + unique_keys = False + + self._match_text_seq("KEYS") + + return_type = self._match_text_seq("RETURNING") and self._parse_format_json( + self._parse_type() + ) + encoding = self._match_text_seq("ENCODING") and self._parse_var() + + return self.expression( + exp.JSONObjectAgg if agg else exp.JSONObject, + expressions=expressions, + null_handling=null_handling, + unique_keys=unique_keys, + return_type=return_type, + encoding=encoding, + ) + + # Note: this is currently incomplete; it only implements the "JSON_value_column" part + def _parse_json_column_def(self) -> exp.JSONColumnDef: + if not self._match_text_seq("NESTED"): + this = self._parse_id_var() + ordinality = self._match_pair(TokenType.FOR, TokenType.ORDINALITY) + kind = self._parse_types(allow_identifiers=False) + nested = None + else: + this = None + ordinality = None + kind = None + nested = True + + path = self._match_text_seq("PATH") and self._parse_string() + nested_schema = nested and self._parse_json_schema() + + return self.expression( + exp.JSONColumnDef, + this=this, + kind=kind, + path=path, + nested_schema=nested_schema, + ordinality=ordinality, + ) + + def _parse_json_schema(self) -> exp.JSONSchema: + self._match_text_seq("COLUMNS") + return self.expression( + exp.JSONSchema, + expressions=self._parse_wrapped_csv( + self._parse_json_column_def, optional=True + ), + ) + + def _parse_json_table(self) -> exp.JSONTable: + this = self._parse_format_json(self._parse_bitwise()) + path = self._match(TokenType.COMMA) and self._parse_string() + error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL") + empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL") + schema = self._parse_json_schema() + + return exp.JSONTable( + this=this, + schema=schema, + path=path, + error_handling=error_handling, + empty_handling=empty_handling, + ) + + def _parse_match_against(self) -> 
exp.MatchAgainst: + if self._match_text_seq("TABLE"): + # parse SingleStore MATCH(TABLE ...) syntax + # https://docs.singlestore.com/cloud/reference/sql-reference/full-text-search-functions/match/ + expressions = [] + table = self._parse_table() + if table: + expressions = [table] + else: + expressions = self._parse_csv(self._parse_column) + + self._match_text_seq(")", "AGAINST", "(") + + this = self._parse_string() + + if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"): + modifier = "IN NATURAL LANGUAGE MODE" + if self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = f"{modifier} WITH QUERY EXPANSION" + elif self._match_text_seq("IN", "BOOLEAN", "MODE"): + modifier = "IN BOOLEAN MODE" + elif self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = "WITH QUERY EXPANSION" + else: + modifier = None + + return self.expression( + exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier + ) + + # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16 + def _parse_open_json(self) -> exp.OpenJSON: + this = self._parse_bitwise() + path = self._match(TokenType.COMMA) and self._parse_string() + + def _parse_open_json_column_def() -> exp.OpenJSONColumnDef: + this = self._parse_field(any_token=True) + kind = self._parse_types() + path = self._parse_string() + as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON) + + return self.expression( + exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json + ) + + expressions = None + if self._match_pair(TokenType.R_PAREN, TokenType.WITH): + self._match_l_paren() + expressions = self._parse_csv(_parse_open_json_column_def) + + return self.expression( + exp.OpenJSON, this=this, path=path, expressions=expressions + ) + + def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: + args = self._parse_csv(self._parse_bitwise) + + if self._match(TokenType.IN): + return self.expression( + exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0) + ) + + if haystack_first: + haystack = seq_get(args, 0) + needle = seq_get(args, 1) + else: + haystack = seq_get(args, 1) + needle = seq_get(args, 0) + + return self.expression( + exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) + ) + + def _parse_join_hint(self, func_name: str) -> exp.JoinHint: + args = self._parse_csv(self._parse_table) + return exp.JoinHint(this=func_name.upper(), expressions=args) + + def _parse_substring(self) -> exp.Substring: + # Postgres supports the form: substring(string [from int] [for int]) + # (despite being undocumented, the reverse order also works) + # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 + + args = t.cast( + t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise) + ) + + start, length = None, None + + while self._curr: + if self._match(TokenType.FROM): + start = self._parse_bitwise() + elif self._match(TokenType.FOR): + if not start: + start = exp.Literal.number(1) + length = self._parse_bitwise() + else: + break + + if start: + args.append(start) + if length: + args.append(length) + + return self.validate_expression(exp.Substring.from_arg_list(args), args) + + def _parse_trim(self) -> exp.Trim: + # https://www.w3resource.com/sql/character-functions/trim.php + # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html + + position = None + collation = None + expression = None + + if self._match_texts(self.TRIM_TYPES): + position = self._prev.text.upper() + + this = 
self._parse_bitwise() + if self._match_set((TokenType.FROM, TokenType.COMMA)): + invert_order = ( + self._prev.token_type == TokenType.FROM or self.TRIM_PATTERN_FIRST + ) + expression = self._parse_bitwise() + + if invert_order: + this, expression = expression, this + + if self._match(TokenType.COLLATE): + collation = self._parse_bitwise() + + return self.expression( + exp.Trim, + this=this, + position=position, + expression=expression, + collation=collation, + ) + + def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]: + return self._match(TokenType.WINDOW) and self._parse_csv( + self._parse_named_window + ) + + def _parse_named_window(self) -> t.Optional[exp.Expression]: + return self._parse_window(self._parse_id_var(), alias=True) + + def _parse_respect_or_ignore_nulls( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if self._match_text_seq("IGNORE", "NULLS"): + return self.expression(exp.IgnoreNulls, this=this) + if self._match_text_seq("RESPECT", "NULLS"): + return self.expression(exp.RespectNulls, this=this) + return this + + def _parse_having_max( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if self._match(TokenType.HAVING): + self._match_texts(("MAX", "MIN")) + max = self._prev.text.upper() != "MIN" + return self.expression( + exp.HavingMax, this=this, expression=self._parse_column(), max=max + ) + + return this + + def _parse_window( + self, this: t.Optional[exp.Expression], alias: bool = False + ) -> t.Optional[exp.Expression]: + func = this + comments = func.comments if isinstance(func, exp.Expression) else None + + # T-SQL allows the OVER (...) syntax after WITHIN GROUP. + # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 + if self._match_text_seq("WITHIN", "GROUP"): + order = self._parse_wrapped(self._parse_order) + this = self.expression(exp.WithinGroup, this=this, expression=order) + + if self._match_pair(TokenType.FILTER, TokenType.L_PAREN): + self._match(TokenType.WHERE) + this = self.expression( + exp.Filter, + this=this, + expression=self._parse_where(skip_where_token=True), + ) + self._match_r_paren() + + # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER + # Some dialects choose to implement and some do not. + # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html + + # There is some code above in _parse_lambda that handles + # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ... + + # The below changes handle + # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ... + + # Oracle allows both formats + # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) + # and Snowflake chose to do the same for familiarity + # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if isinstance(this, exp.AggFunc): + ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls) + + if ignore_respect and ignore_respect is not this: + ignore_respect.replace(ignore_respect.this) + this = self.expression(ignore_respect.__class__, this=this) + + this = self._parse_respect_or_ignore_nulls(this) + + # bigquery select from window x AS (partition by ...) 
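+        # NOTE (editorial, not upstream text): alias=True means we are parsing a named
+        # window, e.g. the `w AS (PARTITION BY x)` in `SELECT ... WINDOW w AS
+        # (PARTITION BY x)`, so there is no OVER keyword to record.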
+ if alias: + over = None + self._match(TokenType.ALIAS) + elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS): + return this + else: + over = self._prev.text.upper() + + if comments and isinstance(func, exp.Expression): + func.pop_comments() + + if not self._match(TokenType.L_PAREN): + return self.expression( + exp.Window, + comments=comments, + this=this, + alias=self._parse_id_var(False), + over=over, + ) + + window_alias = self._parse_id_var( + any_token=False, tokens=self.WINDOW_ALIAS_TOKENS + ) + + first = self._match(TokenType.FIRST) + if self._match_text_seq("LAST"): + first = False + + partition, order = self._parse_partition_and_order() + kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text + + if kind: + self._match(TokenType.BETWEEN) + start = self._parse_window_spec() + + end = self._parse_window_spec() if self._match(TokenType.AND) else {} + exclude = ( + self._parse_var_from_options(self.WINDOW_EXCLUDE_OPTIONS) + if self._match_text_seq("EXCLUDE") + else None + ) + + spec = self.expression( + exp.WindowSpec, + kind=kind, + start=start["value"], + start_side=start["side"], + end=end.get("value"), + end_side=end.get("side"), + exclude=exclude, + ) + else: + spec = None + + self._match_r_paren() + + window = self.expression( + exp.Window, + comments=comments, + this=this, + partition_by=partition, + order=order, + spec=spec, + alias=window_alias, + over=over, + first=first, + ) + + # This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...) + if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False): + return self._parse_window(window, alias=alias) + + return window + + def _parse_partition_and_order( + self, + ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: + return self._parse_partition_by(), self._parse_order() + + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: + self._match(TokenType.BETWEEN) + + return { + "value": ( + (self._match_text_seq("UNBOUNDED") and "UNBOUNDED") + or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW") + or self._parse_bitwise() + ), + "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text, + } + + def _parse_alias( + self, this: t.Optional[exp.Expression], explicit: bool = False + ) -> t.Optional[exp.Expression]: + # In some dialects, LIMIT and OFFSET can act as both identifiers and keywords (clauses) + # so this section tries to parse the clause version and if it fails, it treats the token + # as an identifier (alias) + if self._can_parse_limit_or_offset(): + return this + + any_token = self._match(TokenType.ALIAS) + comments = self._prev_comments or [] + + if explicit and not any_token: + return this + + if self._match(TokenType.L_PAREN): + aliases = self.expression( + exp.Aliases, + comments=comments, + this=this, + expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), + ) + self._match_r_paren(aliases) + return aliases + + alias = self._parse_id_var(any_token, tokens=self.ALIAS_TOKENS) or ( + self.STRING_ALIASES and self._parse_string_as_identifier() + ) + + if alias: + comments.extend(alias.pop_comments()) + this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) + column = this.this + + # Moves the comment next to the alias in `expr /* comment */ AS alias` + if not this.comments and column and column.comments: + this.comments = column.pop_comments() + + return this + + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> 
+    def _parse_id_var(
+        self,
+        any_token: bool = True,
+        tokens: t.Optional[t.Collection[TokenType]] = None,
+    ) -> t.Optional[exp.Expression]:
+        expression = self._parse_identifier()
+        if not expression and (
+            (any_token and self._advance_any())
+            or self._match_set(tokens or self.ID_VAR_TOKENS)
+        ):
+            quoted = self._prev.token_type == TokenType.STRING
+            expression = self._identifier_expression(quoted=quoted)
+
+        return expression
+
+    def _parse_string(self) -> t.Optional[exp.Expression]:
+        if self._match_set(self.STRING_PARSERS):
+            return self.STRING_PARSERS[self._prev.token_type](self, self._prev)
+        return self._parse_placeholder()
+
+    def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
+        output = exp.to_identifier(
+            self._match(TokenType.STRING) and self._prev.text, quoted=True
+        )
+        if output:
+            output.update_positions(self._prev)
+        return output
+
+    def _parse_number(self) -> t.Optional[exp.Expression]:
+        if self._match_set(self.NUMERIC_PARSERS):
+            return self.NUMERIC_PARSERS[self._prev.token_type](self, self._prev)
+        return self._parse_placeholder()
+
+    def _parse_identifier(self) -> t.Optional[exp.Expression]:
+        if self._match(TokenType.IDENTIFIER):
+            return self._identifier_expression(quoted=True)
+        return self._parse_placeholder()
+
+    def _parse_var(
+        self,
+        any_token: bool = False,
+        tokens: t.Optional[t.Collection[TokenType]] = None,
+        upper: bool = False,
+    ) -> t.Optional[exp.Expression]:
+        if (
+            (any_token and self._advance_any())
+            or self._match(TokenType.VAR)
+            or (self._match_set(tokens) if tokens else False)
+        ):
+            return self.expression(
+                exp.Var, this=self._prev.text.upper() if upper else self._prev.text
+            )
+        return self._parse_placeholder()
+
+    def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
+        if self._curr and (
+            ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS
+        ):
+            self._advance()
+            return self._prev
+        return None
+
+    def _parse_var_or_string(self, upper: bool = False) -> t.Optional[exp.Expression]:
+        return self._parse_string() or self._parse_var(any_token=True, upper=upper)
+
+    def _parse_primary_or_var(self) -> t.Optional[exp.Expression]:
+        return self._parse_primary() or self._parse_var(any_token=True)
+
+    def _parse_null(self) -> t.Optional[exp.Expression]:
+        if self._match_set((TokenType.NULL, TokenType.UNKNOWN)):
+            return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
+        return self._parse_placeholder()
+
+    def _parse_boolean(self) -> t.Optional[exp.Expression]:
+        if self._match(TokenType.TRUE):
+            return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
+        if self._match(TokenType.FALSE):
+            return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
+        return self._parse_placeholder()
+
+    def _parse_star(self) -> t.Optional[exp.Expression]:
+        if self._match(TokenType.STAR):
+            return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
+        return self._parse_placeholder()
+
+    def _parse_parameter(self) -> exp.Parameter:
+        this = self._parse_identifier() or self._parse_primary_or_var()
+        return self.expression(exp.Parameter, this=this)
+
+    def _parse_placeholder(self) -> t.Optional[exp.Expression]:
+        if self._match_set(self.PLACEHOLDER_PARSERS):
+            placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
+            if placeholder:
+                return placeholder
+            self._advance(-1)
+        return None
+
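+    # (illustrative) _parse_placeholder above backtracks via self._advance(-1)
+    # when a PLACEHOLDER_PARSERS hook matches the current token but produces no node.
+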
+    def _parse_star_op(self, *keywords: str) -> t.Optional[t.List[exp.Expression]]:
+        if not self._match_texts(keywords):
+            return None
+        if self._match(TokenType.L_PAREN, advance=False):
+            return self._parse_wrapped_csv(self._parse_expression)
+
+        expression = self._parse_alias(self._parse_disjunction(), explicit=True)
+        return [expression] if expression else None
+
+    def _parse_csv(
+        self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+    ) -> t.List[exp.Expression]:
+        parse_result = parse_method()
+        items = [parse_result] if parse_result is not None else []
+
+        while self._match(sep):
+            self._add_comments(parse_result)
+            parse_result = parse_method()
+            if parse_result is not None:
+                items.append(parse_result)
+
+        return items
+
+    def _parse_tokens(
+        self, parse_method: t.Callable, expressions: t.Dict
+    ) -> t.Optional[exp.Expression]:
+        this = parse_method()
+
+        while self._match_set(expressions):
+            this = self.expression(
+                expressions[self._prev.token_type],
+                this=this,
+                comments=self._prev_comments,
+                expression=parse_method(),
+            )
+
+        return this
+
+    def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
+        return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
+
+    def _parse_wrapped_csv(
+        self,
+        parse_method: t.Callable,
+        sep: TokenType = TokenType.COMMA,
+        optional: bool = False,
+    ) -> t.List[exp.Expression]:
+        return self._parse_wrapped(
+            lambda: self._parse_csv(parse_method, sep=sep), optional=optional
+        )
+
+    def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any:
+        wrapped = self._match(TokenType.L_PAREN)
+        if not wrapped and not optional:
+            self.raise_error("Expecting (")
+        parse_result = parse_method()
+        if wrapped:
+            self._match_r_paren()
+        return parse_result
+
+    def _parse_expressions(self) -> t.List[exp.Expression]:
+        return self._parse_csv(self._parse_expression)
+
+    def _parse_select_or_expression(
+        self, alias: bool = False
+    ) -> t.Optional[exp.Expression]:
+        return (
+            self._parse_set_operations(
+                self._parse_alias(self._parse_assignment(), explicit=True)
+                if alias
+                else self._parse_assignment()
+            )
+            or self._parse_select()
+        )
+
+    def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
+        return self._parse_query_modifiers(
+            self._parse_set_operations(
+                self._parse_select(nested=True, parse_subquery_alias=False)
+            )
+        )
+
+    def _parse_transaction(self) -> exp.Transaction | exp.Command:
+        this = None
+        if self._match_texts(self.TRANSACTION_KIND):
+            this = self._prev.text
+
+        self._match_texts(("TRANSACTION", "WORK"))
+
+        modes = []
+        while True:
+            mode = []
+            while self._match(TokenType.VAR) or self._match(TokenType.NOT):
+                mode.append(self._prev.text)
+
+            if mode:
+                modes.append(" ".join(mode))
+            if not self._match(TokenType.COMMA):
+                break
+
+        return self.expression(exp.Transaction, this=this, modes=modes)
+
+    def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
+        chain = None
+        savepoint = None
+        is_rollback = self._prev.token_type == TokenType.ROLLBACK
+
+        self._match_texts(("TRANSACTION", "WORK"))
+
+        if self._match_text_seq("TO"):
+            self._match_text_seq("SAVEPOINT")
+            savepoint = self._parse_id_var()
+
+        if self._match(TokenType.AND):
+            chain = not self._match_text_seq("NO")
+            self._match_text_seq("CHAIN")
+
+        if is_rollback:
+            return self.expression(exp.Rollback, savepoint=savepoint)
+
+        return self.expression(exp.Commit, chain=chain)
+
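+    # (illustrative) COMMIT AND CHAIN         -> exp.Commit(chain=True)
+    #                ROLLBACK TO SAVEPOINT s  -> exp.Rollback(savepoint=s)
+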
+    def _parse_refresh(self) -> exp.Refresh | exp.Command:
+        if self._match(TokenType.TABLE):
+            kind = "TABLE"
+        elif self._match_text_seq("MATERIALIZED", "VIEW"):
+            kind = "MATERIALIZED VIEW"
+        else:
+            kind = ""
+
+        this = self._parse_string() or self._parse_table()
+        if not kind and not isinstance(this, exp.Literal):
+            return self._parse_as_command(self._prev)
+
+        return self.expression(exp.Refresh, this=this, kind=kind)
+
+    def _parse_column_def_with_exists(self):
+        start = self._index
+        self._match(TokenType.COLUMN)
+
+        exists_column = self._parse_exists(not_=True)
+        expression = self._parse_field_def()
+
+        if not isinstance(expression, exp.ColumnDef):
+            self._retreat(start)
+            return None
+
+        expression.set("exists", exists_column)
+
+        return expression
+
+    def _parse_add_column(self) -> t.Optional[exp.ColumnDef]:
+        if not self._prev.text.upper() == "ADD":
+            return None
+
+        expression = self._parse_column_def_with_exists()
+        if not expression:
+            return None
+
+        # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns
+        if self._match_texts(("FIRST", "AFTER")):
+            position = self._prev.text
+            column_position = self.expression(
+                exp.ColumnPosition, this=self._parse_column(), position=position
+            )
+            expression.set("position", column_position)
+
+        return expression
+
+    def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
+        drop = self._match(TokenType.DROP) and self._parse_drop()
+        if drop and not isinstance(drop, exp.Command):
+            drop.set("kind", drop.args.get("kind", "COLUMN"))
+        return drop
+
+    # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
+    def _parse_drop_partition(
+        self, exists: t.Optional[bool] = None
+    ) -> exp.DropPartition:
+        return self.expression(
+            exp.DropPartition,
+            expressions=self._parse_csv(self._parse_partition),
+            exists=exists,
+        )
+
+    def _parse_alter_table_add(self) -> t.List[exp.Expression]:
+        def _parse_add_alteration() -> t.Optional[exp.Expression]:
+            self._match_text_seq("ADD")
+            if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False):
+                return self.expression(
+                    exp.AddConstraint,
+                    expressions=self._parse_csv(self._parse_constraint),
+                )
+
+            column_def = self._parse_add_column()
+            if isinstance(column_def, exp.ColumnDef):
+                return column_def
+
+            exists = self._parse_exists(not_=True)
+            if self._match_pair(TokenType.PARTITION, TokenType.L_PAREN, advance=False):
+                return self.expression(
+                    exp.AddPartition,
+                    exists=exists,
+                    this=self._parse_field(any_token=True),
+                    location=self._match_text_seq("LOCATION", advance=False)
+                    and self._parse_property(),
+                )
+
+            return None
+
+        if not self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False) and (
+            not self.dialect.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN
+            or self._match_text_seq("COLUMNS")
+        ):
+            schema = self._parse_schema()
+
+            return (
+                ensure_list(schema)
+                if schema
+                else self._parse_csv(self._parse_column_def_with_exists)
+            )
+
+        return self._parse_csv(_parse_add_alteration)
+
+    def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
+        if self._match_texts(self.ALTER_ALTER_PARSERS):
+            return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self)
+
+        # Many dialects support the ALTER [COLUMN] syntax, so if there is no
+        # keyword after ALTER we default to parsing this statement
+        self._match(TokenType.COLUMN)
+        column = self._parse_field(any_token=True)
+
+        if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
+            return self.expression(exp.AlterColumn, this=column, drop=True)
+        if self._match_pair(TokenType.SET, TokenType.DEFAULT):
+            return self.expression(
+                exp.AlterColumn, this=column, default=self._parse_disjunction()
+            )
+        if self._match(TokenType.COMMENT):
+            return self.expression(
+                exp.AlterColumn, this=column, comment=self._parse_string()
+            )
+        if self._match_text_seq("DROP", "NOT", "NULL"):
+            return self.expression(
+                exp.AlterColumn,
+                this=column,
+                drop=True,
+                allow_null=True,
+            )
self._match_text_seq("SET", "NOT", "NULL"): + return self.expression( + exp.AlterColumn, + this=column, + allow_null=False, + ) + + if self._match_text_seq("SET", "VISIBLE"): + return self.expression(exp.AlterColumn, this=column, visible="VISIBLE") + if self._match_text_seq("SET", "INVISIBLE"): + return self.expression(exp.AlterColumn, this=column, visible="INVISIBLE") + + self._match_text_seq("SET", "DATA") + self._match_text_seq("TYPE") + return self.expression( + exp.AlterColumn, + this=column, + dtype=self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_disjunction(), + ) + + def _parse_alter_diststyle(self) -> exp.AlterDistStyle: + if self._match_texts(("ALL", "EVEN", "AUTO")): + return self.expression( + exp.AlterDistStyle, this=exp.var(self._prev.text.upper()) + ) + + self._match_text_seq("KEY", "DISTKEY") + return self.expression(exp.AlterDistStyle, this=self._parse_column()) + + def _parse_alter_sortkey( + self, compound: t.Optional[bool] = None + ) -> exp.AlterSortKey: + if compound: + self._match_text_seq("SORTKEY") + + if self._match(TokenType.L_PAREN, advance=False): + return self.expression( + exp.AlterSortKey, + expressions=self._parse_wrapped_id_vars(), + compound=compound, + ) + + self._match_texts(("AUTO", "NONE")) + return self.expression( + exp.AlterSortKey, this=exp.var(self._prev.text.upper()), compound=compound + ) + + def _parse_alter_table_drop(self) -> t.List[exp.Expression]: + index = self._index - 1 + + partition_exists = self._parse_exists() + if self._match(TokenType.PARTITION, advance=False): + return self._parse_csv( + lambda: self._parse_drop_partition(exists=partition_exists) + ) + + self._retreat(index) + return self._parse_csv(self._parse_drop_column) + + def _parse_alter_table_rename( + self, + ) -> t.Optional[exp.AlterRename | exp.RenameColumn]: + if self._match(TokenType.COLUMN) or not self.ALTER_RENAME_REQUIRES_COLUMN: + exists = self._parse_exists() + old_column = self._parse_column() + to = self._match_text_seq("TO") + new_column = self._parse_column() + + if old_column is None or to is None or new_column is None: + return None + + return self.expression( + exp.RenameColumn, this=old_column, to=new_column, exists=exists + ) + + self._match_text_seq("TO") + return self.expression(exp.AlterRename, this=self._parse_table(schema=True)) + + def _parse_alter_table_set(self) -> exp.AlterSet: + alter_set = self.expression(exp.AlterSet) + + if self._match(TokenType.L_PAREN, advance=False) or self._match_text_seq( + "TABLE", "PROPERTIES" + ): + alter_set.set( + "expressions", self._parse_wrapped_csv(self._parse_assignment) + ) + elif self._match_text_seq("FILESTREAM_ON", advance=False): + alter_set.set("expressions", [self._parse_assignment()]) + elif self._match_texts(("LOGGED", "UNLOGGED")): + alter_set.set("option", exp.var(self._prev.text.upper())) + elif self._match_text_seq("WITHOUT") and self._match_texts(("CLUSTER", "OIDS")): + alter_set.set("option", exp.var(f"WITHOUT {self._prev.text.upper()}")) + elif self._match_text_seq("LOCATION"): + alter_set.set("location", self._parse_field()) + elif self._match_text_seq("ACCESS", "METHOD"): + alter_set.set("access_method", self._parse_field()) + elif self._match_text_seq("TABLESPACE"): + alter_set.set("tablespace", self._parse_field()) + elif self._match_text_seq("FILE", "FORMAT") or self._match_text_seq( + "FILEFORMAT" + ): + alter_set.set("file_format", [self._parse_field()]) + elif 
self._match_text_seq("STAGE_FILE_FORMAT"): + alter_set.set("file_format", self._parse_wrapped_options()) + elif self._match_text_seq("STAGE_COPY_OPTIONS"): + alter_set.set("copy_options", self._parse_wrapped_options()) + elif self._match_text_seq("TAG") or self._match_text_seq("TAGS"): + alter_set.set("tag", self._parse_csv(self._parse_assignment)) + else: + if self._match_text_seq("SERDE"): + alter_set.set("serde", self._parse_field()) + + properties = self._parse_wrapped(self._parse_properties, optional=True) + alter_set.set("expressions", [properties]) + + return alter_set + + def _parse_alter_session(self) -> exp.AlterSession: + """Parse ALTER SESSION SET/UNSET statements.""" + if self._match(TokenType.SET): + expressions = self._parse_csv(lambda: self._parse_set_item_assignment()) + return self.expression( + exp.AlterSession, expressions=expressions, unset=False + ) + + self._match_text_seq("UNSET") + expressions = self._parse_csv( + lambda: self.expression( + exp.SetItem, this=self._parse_id_var(any_token=True) + ) + ) + return self.expression(exp.AlterSession, expressions=expressions, unset=True) + + def _parse_alter(self) -> exp.Alter | exp.Command: + start = self._prev + + alter_token = self._match_set(self.ALTERABLES) and self._prev + if not alter_token: + return self._parse_as_command(start) + + exists = self._parse_exists() + only = self._match_text_seq("ONLY") + + if alter_token.token_type == TokenType.SESSION: + this = None + check = None + cluster = None + else: + this = self._parse_table( + schema=True, parse_partition=self.ALTER_TABLE_PARTITIONS + ) + check = self._match_text_seq("WITH", "CHECK") + cluster = self._parse_on_property() if self._match(TokenType.ON) else None + + if self._next: + self._advance() + + parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None + if parser: + actions = ensure_list(parser(self)) + not_valid = self._match_text_seq("NOT", "VALID") + options = self._parse_csv(self._parse_property) + cascade = ( + self.dialect.ALTER_TABLE_SUPPORTS_CASCADE + and self._match_text_seq("CASCADE") + ) + + if not self._curr and actions: + return self.expression( + exp.Alter, + this=this, + kind=alter_token.text.upper(), + exists=exists, + actions=actions, + only=only, + options=options, + cluster=cluster, + not_valid=not_valid, + check=check, + cascade=cascade, + ) + + return self._parse_as_command(start) + + def _parse_analyze(self) -> exp.Analyze | exp.Command: + start = self._prev + # https://duckdb.org/docs/sql/statements/analyze + if not self._curr: + return self.expression(exp.Analyze) + + options = [] + while self._match_texts(self.ANALYZE_STYLES): + if self._prev.text.upper() == "BUFFER_USAGE_LIMIT": + options.append(f"BUFFER_USAGE_LIMIT {self._parse_number()}") + else: + options.append(self._prev.text.upper()) + + this: t.Optional[exp.Expression] = None + inner_expression: t.Optional[exp.Expression] = None + + kind = self._curr and self._curr.text.upper() + + if self._match(TokenType.TABLE) or self._match(TokenType.INDEX): + this = self._parse_table_parts() + elif self._match_text_seq("TABLES"): + if self._match_set((TokenType.FROM, TokenType.IN)): + kind = f"{kind} {self._prev.text.upper()}" + this = self._parse_table(schema=True, is_db_reference=True) + elif self._match_text_seq("DATABASE"): + this = self._parse_table(schema=True, is_db_reference=True) + elif self._match_text_seq("CLUSTER"): + this = self._parse_table() + # Try matching inner expr keywords before fallback to parse table. 
+        elif self._match_texts(self.ANALYZE_EXPRESSION_PARSERS):
+            kind = None
+            inner_expression = self.ANALYZE_EXPRESSION_PARSERS[self._prev.text.upper()](
+                self
+            )
+        else:
+            # Empty kind https://prestodb.io/docs/current/sql/analyze.html
+            kind = None
+            this = self._parse_table_parts()
+
+        partition = self._try_parse(self._parse_partition)
+        if not partition and self._match_texts(self.PARTITION_KEYWORDS):
+            return self._parse_as_command(start)
+
+        # https://docs.starrocks.io/docs/sql-reference/sql-statements/cbo_stats/ANALYZE_TABLE/
+        if self._match_text_seq("WITH", "SYNC", "MODE") or self._match_text_seq(
+            "WITH", "ASYNC", "MODE"
+        ):
+            mode = f"WITH {self._tokens[self._index - 2].text.upper()} MODE"
+        else:
+            mode = None
+
+        if self._match_texts(self.ANALYZE_EXPRESSION_PARSERS):
+            inner_expression = self.ANALYZE_EXPRESSION_PARSERS[self._prev.text.upper()](
+                self
+            )
+
+        properties = self._parse_properties()
+        return self.expression(
+            exp.Analyze,
+            kind=kind,
+            this=this,
+            mode=mode,
+            partition=partition,
+            properties=properties,
+            expression=inner_expression,
+            options=options,
+        )
+
+    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-aux-analyze-table.html
+    def _parse_analyze_statistics(self) -> exp.AnalyzeStatistics:
+        this = None
+        kind = self._prev.text.upper()
+        option = self._prev.text.upper() if self._match_text_seq("DELTA") else None
+        expressions = []
+
+        if not self._match_text_seq("STATISTICS"):
+            self.raise_error("Expecting token STATISTICS")
+
+        if self._match_text_seq("NOSCAN"):
+            this = "NOSCAN"
+        elif self._match(TokenType.FOR):
+            if self._match_text_seq("ALL", "COLUMNS"):
+                this = "FOR ALL COLUMNS"
+            if self._match_texts("COLUMNS"):
+                this = "FOR COLUMNS"
+                expressions = self._parse_csv(self._parse_column_reference)
+        elif self._match_text_seq("SAMPLE"):
+            sample = self._parse_number()
+            expressions = [
+                self.expression(
+                    exp.AnalyzeSample,
+                    sample=sample,
+                    kind=self._prev.text.upper()
+                    if self._match(TokenType.PERCENT)
+                    else None,
+                )
+            ]
+
+        return self.expression(
+            exp.AnalyzeStatistics,
+            kind=kind,
+            option=option,
+            this=this,
+            expressions=expressions,
+        )
+
+    # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ANALYZE.html
+    def _parse_analyze_validate(self) -> exp.AnalyzeValidate:
+        kind = None
+        this = None
+        expression: t.Optional[exp.Expression] = None
+        if self._match_text_seq("REF", "UPDATE"):
+            kind = "REF"
+            this = "UPDATE"
+            if self._match_text_seq("SET", "DANGLING", "TO", "NULL"):
+                this = "UPDATE SET DANGLING TO NULL"
+        elif self._match_text_seq("STRUCTURE"):
+            kind = "STRUCTURE"
+            if self._match_text_seq("CASCADE", "FAST"):
+                this = "CASCADE FAST"
+            elif self._match_text_seq("CASCADE", "COMPLETE") and self._match_texts(
+                ("ONLINE", "OFFLINE")
+            ):
+                this = f"CASCADE COMPLETE {self._prev.text.upper()}"
+                expression = self._parse_into()
+
+        return self.expression(
+            exp.AnalyzeValidate, kind=kind, this=this, expression=expression
+        )
+
+    def _parse_analyze_columns(self) -> t.Optional[exp.AnalyzeColumns]:
+        this = self._prev.text.upper()
+        if self._match_text_seq("COLUMNS"):
+            return self.expression(
+                exp.AnalyzeColumns, this=f"{this} {self._prev.text.upper()}"
+            )
+        return None
+
+    def _parse_analyze_delete(self) -> t.Optional[exp.AnalyzeDelete]:
+        kind = self._prev.text.upper() if self._match_text_seq("SYSTEM") else None
+        if self._match_text_seq("STATISTICS"):
+            return self.expression(exp.AnalyzeDelete, kind=kind)
+        return None
+
self._match_text_seq("CHAINED", "ROWS"): + return self.expression( + exp.AnalyzeListChainedRows, expression=self._parse_into() + ) + return None + + # https://dev.mysql.com/doc/refman/8.4/en/analyze-table.html + def _parse_analyze_histogram(self) -> exp.AnalyzeHistogram: + this = self._prev.text.upper() + expression: t.Optional[exp.Expression] = None + expressions = [] + update_options = None + + if self._match_text_seq("HISTOGRAM", "ON"): + expressions = self._parse_csv(self._parse_column_reference) + with_expressions = [] + while self._match(TokenType.WITH): + # https://docs.starrocks.io/docs/sql-reference/sql-statements/cbo_stats/ANALYZE_TABLE/ + if self._match_texts(("SYNC", "ASYNC")): + if self._match_text_seq("MODE", advance=False): + with_expressions.append(f"{self._prev.text.upper()} MODE") + self._advance() + else: + buckets = self._parse_number() + if self._match_text_seq("BUCKETS"): + with_expressions.append(f"{buckets} BUCKETS") + if with_expressions: + expression = self.expression( + exp.AnalyzeWith, expressions=with_expressions + ) + + if self._match_texts(("MANUAL", "AUTO")) and self._match( + TokenType.UPDATE, advance=False + ): + update_options = self._prev.text.upper() + self._advance() + elif self._match_text_seq("USING", "DATA"): + expression = self.expression(exp.UsingData, this=self._parse_string()) + + return self.expression( + exp.AnalyzeHistogram, + this=this, + expressions=expressions, + expression=expression, + update_options=update_options, + ) + + def _parse_merge(self) -> exp.Merge: + self._match(TokenType.INTO) + target = self._parse_table() + + if target and self._match(TokenType.ALIAS, advance=False): + target.set("alias", self._parse_table_alias()) + + self._match(TokenType.USING) + using = self._parse_table() + + return self.expression( + exp.Merge, + this=target, + using=using, + on=self._match(TokenType.ON) and self._parse_disjunction(), + using_cond=self._match(TokenType.USING) and self._parse_using_identifiers(), + whens=self._parse_when_matched(), + returning=self._parse_returning(), + ) + + def _parse_when_matched(self) -> exp.Whens: + whens = [] + + while self._match(TokenType.WHEN): + matched = not self._match(TokenType.NOT) + self._match_text_seq("MATCHED") + source = ( + False + if self._match_text_seq("BY", "TARGET") + else self._match_text_seq("BY", "SOURCE") + ) + condition = ( + self._parse_disjunction() if self._match(TokenType.AND) else None + ) + + self._match(TokenType.THEN) + + if self._match(TokenType.INSERT): + this = self._parse_star() + if this: + then: t.Optional[exp.Expression] = self.expression( + exp.Insert, this=this + ) + else: + then = self.expression( + exp.Insert, + this=exp.var("ROW") + if self._match_text_seq("ROW") + else self._parse_value(values=False), + expression=self._match_text_seq("VALUES") + and self._parse_value(), + ) + elif self._match(TokenType.UPDATE): + expressions = self._parse_star() + if expressions: + then = self.expression(exp.Update, expressions=expressions) + else: + then = self.expression( + exp.Update, + expressions=self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + ) + elif self._match(TokenType.DELETE): + then = self.expression(exp.Var, this=self._prev.text) + else: + then = self._parse_var_from_options(self.CONFLICT_ACTIONS) + + whens.append( + self.expression( + exp.When, + matched=matched, + source=source, + condition=condition, + then=then, + ) + ) + return self.expression(exp.Whens, expressions=whens) + + def _parse_show(self) -> t.Optional[exp.Expression]: + parser = 
+    def _parse_show(self) -> t.Optional[exp.Expression]:
+        parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
+        if parser:
+            return parser(self)
+        return self._parse_as_command(self._prev)
+
+    def _parse_set_item_assignment(
+        self, kind: t.Optional[str] = None
+    ) -> t.Optional[exp.Expression]:
+        index = self._index
+
+        if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"):
+            return self._parse_set_transaction(global_=kind == "GLOBAL")
+
+        left = self._parse_primary() or self._parse_column()
+        assignment_delimiter = self._match_texts(self.SET_ASSIGNMENT_DELIMITERS)
+
+        if not left or (
+            self.SET_REQUIRES_ASSIGNMENT_DELIMITER and not assignment_delimiter
+        ):
+            self._retreat(index)
+            return None
+
+        right = self._parse_statement() or self._parse_id_var()
+        if isinstance(right, (exp.Column, exp.Identifier)):
+            right = exp.var(right.name)
+
+        this = self.expression(exp.EQ, this=left, expression=right)
+        return self.expression(exp.SetItem, this=this, kind=kind)
+
+    def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
+        self._match_text_seq("TRANSACTION")
+        characteristics = self._parse_csv(
+            lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+        )
+        return self.expression(
+            exp.SetItem,
+            expressions=characteristics,
+            kind="TRANSACTION",
+            global_=global_,
+        )
+
+    def _parse_set_item(self) -> t.Optional[exp.Expression]:
+        parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
+        return parser(self) if parser else self._parse_set_item_assignment(kind=None)
+
+    def _parse_set(
+        self, unset: bool = False, tag: bool = False
+    ) -> exp.Set | exp.Command:
+        index = self._index
+        set_ = self.expression(
+            exp.Set,
+            expressions=self._parse_csv(self._parse_set_item),
+            unset=unset,
+            tag=tag,
+        )
+
+        if self._curr:
+            self._retreat(index)
+            return self._parse_as_command(self._prev)
+
+        return set_
+
+    def _parse_var_from_options(
+        self, options: OPTIONS_TYPE, raise_unmatched: bool = True
+    ) -> t.Optional[exp.Var]:
+        start = self._curr
+        if not start:
+            return None
+
+        option = start.text.upper()
+        continuations = options.get(option)
+
+        index = self._index
+        self._advance()
+        for keywords in continuations or []:
+            if isinstance(keywords, str):
+                keywords = (keywords,)
+
+            if self._match_text_seq(*keywords):
+                option = f"{option} {' '.join(keywords)}"
+                break
+        else:
+            if continuations or continuations is None:
+                if raise_unmatched:
+                    self.raise_error(f"Unknown option {option}")
+
+                self._retreat(index)
+                return None
+
+        return exp.var(option)
+
+    def _parse_as_command(self, start: Token) -> exp.Command:
+        while self._curr:
+            self._advance()
+        text = self._find_sql(start, self._prev)
+        size = len(start.text)
+        self._warn_unsupported()
+        return exp.Command(this=text[:size], expression=text[size:])
+
+    def _parse_dict_property(self, this: str) -> exp.DictProperty:
+        settings = []
+
+        self._match_l_paren()
+        kind = self._parse_id_var()
+
+        if self._match(TokenType.L_PAREN):
+            while True:
+                key = self._parse_id_var()
+                value = self._parse_primary()
+                if not key and value is None:
+                    break
+                settings.append(
+                    self.expression(exp.DictSubProperty, this=key, value=value)
+                )
+            self._match(TokenType.R_PAREN)
+
+        self._match_r_paren()
+
+        return self.expression(
+            exp.DictProperty,
+            this=this,
+            kind=kind.this if kind else None,
+            settings=settings,
+        )
+
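+    # (illustrative) dictionary DDL such as LAYOUT(FLAT()) is handled by
+    # _parse_dict_property above, while RANGE(MIN a MAX b) is parsed below into
+    # exp.DictRange(min=a, max=b).
+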
+    def _parse_dict_range(self, this: str) -> exp.DictRange:
+        self._match_l_paren()
+        has_min = self._match_text_seq("MIN")
+        if has_min:
+            min = self._parse_var() or self._parse_primary()
+            self._match_text_seq("MAX")
+            max = self._parse_var() or self._parse_primary()
+        else:
+            max = self._parse_var() or self._parse_primary()
+            min = exp.Literal.number(0)
+        self._match_r_paren()
+        return self.expression(exp.DictRange, this=this, min=min, max=max)
+
+    def _parse_comprehension(
+        self, this: t.Optional[exp.Expression]
+    ) -> t.Optional[exp.Comprehension]:
+        index = self._index
+        expression = self._parse_column()
+        position = self._match(TokenType.COMMA) and self._parse_column()
+
+        if not self._match(TokenType.IN):
+            self._retreat(index - 1)
+            return None
+        iterator = self._parse_column()
+        condition = self._parse_disjunction() if self._match_text_seq("IF") else None
+        return self.expression(
+            exp.Comprehension,
+            this=this,
+            expression=expression,
+            position=position,
+            iterator=iterator,
+            condition=condition,
+        )
+
+    def _parse_heredoc(self) -> t.Optional[exp.Heredoc]:
+        if self._match(TokenType.HEREDOC_STRING):
+            return self.expression(exp.Heredoc, this=self._prev.text)
+
+        if not self._match_text_seq("$"):
+            return None
+
+        tags = ["$"]
+        tag_text = None
+
+        if self._is_connected():
+            self._advance()
+            tags.append(self._prev.text.upper())
+        else:
+            self.raise_error("No closing $ found")
+
+        if tags[-1] != "$":
+            if self._is_connected() and self._match_text_seq("$"):
+                tag_text = tags[-1]
+                tags.append("$")
+            else:
+                self.raise_error("No closing $ found")
+
+        heredoc_start = self._curr
+
+        while self._curr:
+            if self._match_text_seq(*tags, advance=False):
+                this = self._find_sql(heredoc_start, self._prev)
+                self._advance(len(tags))
+                return self.expression(exp.Heredoc, this=this, tag=tag_text)
+
+            self._advance()
+
+        self.raise_error(f"No closing {''.join(tags)} found")
+        return None
+
+    def _find_parser(
+        self, parsers: t.Dict[str, t.Callable], trie: t.Dict
+    ) -> t.Optional[t.Callable]:
+        if not self._curr:
+            return None
+
+        index = self._index
+        this = []
+        while True:
+            # The current token might be multiple words
+            curr = self._curr.text.upper()
+            key = curr.split(" ")
+            this.append(curr)
+
+            self._advance()
+            result, trie = in_trie(trie, key)
+            if result == TrieResult.FAILED:
+                break
+
+            if result == TrieResult.EXISTS:
+                subparser = parsers[" ".join(this)]
+                return subparser
+
+        self._retreat(index)
+        return None
+
+    def _match(self, token_type, advance=True, expression=None):
+        if not self._curr:
+            return None
+
+        if self._curr.token_type == token_type:
+            if advance:
+                self._advance()
+            self._add_comments(expression)
+            return True
+
+        return None
+
+    def _match_set(self, types, advance=True):
+        if not self._curr:
+            return None
+
+        if self._curr.token_type in types:
+            if advance:
+                self._advance()
+            return True
+
+        return None
+
+    def _match_pair(self, token_type_a, token_type_b, advance=True):
+        if not self._curr or not self._next:
+            return None
+
+        if (
+            self._curr.token_type == token_type_a
+            and self._next.token_type == token_type_b
+        ):
+            if advance:
+                self._advance(2)
+            return True
+
+        return None
+
+    def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
+        if not self._match(TokenType.L_PAREN, expression=expression):
+            self.raise_error("Expecting (")
+
+    def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
+        if not self._match(TokenType.R_PAREN, expression=expression):
+            self.raise_error("Expecting )")
+
+    def _match_texts(self, texts, advance=True):
+        if (
+            self._curr
+            and self._curr.token_type != TokenType.STRING
+            and self._curr.text.upper() in texts
+        ):
+            if advance:
+                self._advance()
+            return True
+        return None
+
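+    # (illustrative) _match_text_seq("SET", "DATA", "TYPE") below consumes the
+    # whole keyword sequence or nothing: on any partial match it retreats to the
+    # saved index and returns None.
+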
+    def _match_text_seq(self, *texts, advance=True):
+        index = self._index
+        for text in texts:
+            if (
+                self._curr
+                and self._curr.token_type != TokenType.STRING
+                and self._curr.text.upper() == text
+            ):
+                self._advance()
+            else:
+                self._retreat(index)
+                return None
+
+        if not advance:
+            self._retreat(index)
+
+        return True
+
+    def _replace_lambda(
+        self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression]
+    ) -> t.Optional[exp.Expression]:
+        if not node:
+            return node
+
+        lambda_types = {e.name: e.args.get("to") or False for e in expressions}
+
+        for column in node.find_all(exp.Column):
+            typ = lambda_types.get(column.parts[0].name)
+            if typ is not None:
+                dot_or_id = column.to_dot() if column.table else column.this
+
+                if typ:
+                    dot_or_id = self.expression(
+                        exp.Cast,
+                        this=dot_or_id,
+                        to=typ,
+                    )
+
+                parent = column.parent
+
+                while isinstance(parent, exp.Dot):
+                    if not isinstance(parent.parent, exp.Dot):
+                        parent.replace(dot_or_id)
+                        break
+                    parent = parent.parent
+                else:
+                    if column is node:
+                        node = dot_or_id
+                    else:
+                        column.replace(dot_or_id)
+        return node
+
+    def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expression:
+        start = self._prev
+
+        # Not to be confused with TRUNCATE(number, decimals) function call
+        if self._match(TokenType.L_PAREN):
+            self._retreat(self._index - 2)
+            return self._parse_function()
+
+        # Clickhouse supports TRUNCATE DATABASE as well
+        is_database = self._match(TokenType.DATABASE)
+
+        self._match(TokenType.TABLE)
+
+        exists = self._parse_exists(not_=False)
+
+        expressions = self._parse_csv(
+            lambda: self._parse_table(schema=True, is_db_reference=is_database)
+        )
+
+        cluster = self._parse_on_property() if self._match(TokenType.ON) else None
+
+        if self._match_text_seq("RESTART", "IDENTITY"):
+            identity = "RESTART"
+        elif self._match_text_seq("CONTINUE", "IDENTITY"):
+            identity = "CONTINUE"
+        else:
+            identity = None
+
+        if self._match_text_seq("CASCADE") or self._match_text_seq("RESTRICT"):
+            option = self._prev.text
+        else:
+            option = None
+
+        partition = self._parse_partition()
+
+        # Fallback case
+        if self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(
+            exp.TruncateTable,
+            expressions=expressions,
+            is_database=is_database,
+            exists=exists,
+            cluster=cluster,
+            identity=identity,
+            option=option,
+            partition=partition,
+        )
+
+    def _parse_with_operator(self) -> t.Optional[exp.Expression]:
+        this = self._parse_ordered(self._parse_opclass)
+
+        if not self._match(TokenType.WITH):
+            return this
+
+        op = self._parse_var(any_token=True)
+
+        return self.expression(exp.WithOperator, this=this, op=op)
+
+    def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
+        self._match(TokenType.EQ)
+        self._match(TokenType.L_PAREN)
+
+        opts: t.List[t.Optional[exp.Expression]] = []
+        option: exp.Expression | None
+        while self._curr and not self._match(TokenType.R_PAREN):
+            if self._match_text_seq("FORMAT_NAME", "="):
+                # The FORMAT_NAME can be set to an identifier for Snowflake and T-SQL
+                option = self._parse_format_name()
+            else:
+                option = self._parse_property()
+
+            if option is None:
+                self.raise_error("Unable to parse option")
+                break
+
+            opts.append(option)
+
+        return opts
+
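+    # (illustrative) _parse_wrapped_options above handles Snowflake-style blocks
+    # like FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|'), producing one parsed
+    # property per inner assignment.
+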
+    def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
+        sep = TokenType.COMMA if self.dialect.COPY_PARAMS_ARE_CSV else None
+
+        options = []
+        while self._curr and not self._match(TokenType.R_PAREN, advance=False):
+            option = self._parse_var(any_token=True)
+            prev = self._prev.text.upper()
+
+            # Different dialects might separate options and values by white space, "=" and "AS"
+            self._match(TokenType.EQ)
+            self._match(TokenType.ALIAS)
+
+            param = self.expression(exp.CopyParameter, this=option)
+
+            if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match(
+                TokenType.L_PAREN, advance=False
+            ):
+                # Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options
+                param.set("expressions", self._parse_wrapped_options())
+            elif prev == "FILE_FORMAT":
+                # T-SQL's external file format case
+                param.set("expression", self._parse_field())
+            elif (
+                prev == "FORMAT"
+                and self._prev.token_type == TokenType.ALIAS
+                and self._match_texts(("AVRO", "JSON"))
+            ):
+                param.set("this", exp.var(f"FORMAT AS {self._prev.text.upper()}"))
+                param.set("expression", self._parse_field())
+            else:
+                param.set(
+                    "expression", self._parse_unquoted_field() or self._parse_bracket()
+                )
+
+            options.append(param)
+            self._match(sep)
+
+        return options
+
+    def _parse_credentials(self) -> t.Optional[exp.Credentials]:
+        expr = self.expression(exp.Credentials)
+
+        if self._match_text_seq("STORAGE_INTEGRATION", "="):
+            expr.set("storage", self._parse_field())
+        if self._match_text_seq("CREDENTIALS"):
+            # Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS
+            creds = (
+                self._parse_wrapped_options()
+                if self._match(TokenType.EQ)
+                else self._parse_field()
+            )
+            expr.set("credentials", creds)
+        if self._match_text_seq("ENCRYPTION"):
+            expr.set("encryption", self._parse_wrapped_options())
+        if self._match_text_seq("IAM_ROLE"):
+            expr.set(
+                "iam_role",
+                exp.var(self._prev.text)
+                if self._match(TokenType.DEFAULT)
+                else self._parse_field(),
+            )
+        if self._match_text_seq("REGION"):
+            expr.set("region", self._parse_field())
+
+        return expr
+
+    def _parse_file_location(self) -> t.Optional[exp.Expression]:
+        return self._parse_field()
+
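+    # (illustrative) COPY INTO t FROM 's3://bucket/path' parses below with
+    # kind=True (loading into the table), while COPY t TO '...' yields kind=False.
+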
+    def _parse_copy(self) -> exp.Copy | exp.Command:
+        start = self._prev
+
+        self._match(TokenType.INTO)
+
+        this = (
+            self._parse_select(nested=True, parse_subquery_alias=False)
+            if self._match(TokenType.L_PAREN, advance=False)
+            else self._parse_table(schema=True)
+        )
+
+        kind = self._match(TokenType.FROM) or not self._match_text_seq("TO")
+
+        files = self._parse_csv(self._parse_file_location)
+        if self._match(TokenType.EQ, advance=False):
+            # Backtrack one token since we've consumed the lhs of a parameter assignment here.
+            # This can happen for Snowflake dialect. Instead, we'd like to parse the parameter
+            # list via `_parse_wrapped(..)` below.
+            self._advance(-1)
+            files = []
+
+        credentials = self._parse_credentials()
+
+        self._match_text_seq("WITH")
+
+        params = self._parse_wrapped(self._parse_copy_parameters, optional=True)
+
+        # Fallback case
+        if self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(
+            exp.Copy,
+            this=this,
+            kind=kind,
+            credentials=credentials,
+            files=files,
+            params=params,
+        )
+
+    def _parse_normalize(self) -> exp.Normalize:
+        return self.expression(
+            exp.Normalize,
+            this=self._parse_bitwise(),
+            form=self._match(TokenType.COMMA) and self._parse_var(),
+        )
+
+    def _parse_ceil_floor(self, expr_type: t.Type[TCeilFloor]) -> TCeilFloor:
+        args = self._parse_csv(lambda: self._parse_lambda())
+
+        this = seq_get(args, 0)
+        decimals = seq_get(args, 1)
+
+        return expr_type(
+            this=this,
+            decimals=decimals,
+            to=self._match_text_seq("TO") and self._parse_var(),
+        )
+
+    def _parse_star_ops(self) -> t.Optional[exp.Expression]:
+        star_token = self._prev
+
+        if self._match_text_seq("COLUMNS", "(", advance=False):
+            this = self._parse_function()
+            if isinstance(this, exp.Columns):
+                this.set("unpack", True)
+            return this
+
+        return self.expression(
+            exp.Star,
+            except_=self._parse_star_op("EXCEPT", "EXCLUDE"),
+            replace=self._parse_star_op("REPLACE"),
+            rename=self._parse_star_op("RENAME"),
+        ).update_positions(star_token)
+
+    def _parse_grant_privilege(self) -> t.Optional[exp.GrantPrivilege]:
+        privilege_parts = []
+
+        # Keep consuming consecutive keywords until comma (end of this privilege) or ON
+        # (end of privilege list) or L_PAREN (start of column list) are met
+        while self._curr and not self._match_set(
+            self.PRIVILEGE_FOLLOW_TOKENS, advance=False
+        ):
+            privilege_parts.append(self._curr.text.upper())
+            self._advance()
+
+        this = exp.var(" ".join(privilege_parts))
+        expressions = (
+            self._parse_wrapped_csv(self._parse_column)
+            if self._match(TokenType.L_PAREN, advance=False)
+            else None
+        )
+
+        return self.expression(exp.GrantPrivilege, this=this, expressions=expressions)
+
+    def _parse_grant_principal(self) -> t.Optional[exp.GrantPrincipal]:
+        kind = self._match_texts(("ROLE", "GROUP")) and self._prev.text.upper()
+        principal = self._parse_id_var()
+
+        if not principal:
+            return None
+
+        return self.expression(exp.GrantPrincipal, this=principal, kind=kind)
+
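+    # (illustrative) GRANT SELECT, INSERT (col) ON TABLE db.t TO ROLE r yields
+    # exp.GrantPrivilege nodes (with optional column lists) plus an
+    # exp.GrantPrincipal with kind="ROLE".
+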
+    def _parse_grant_revoke_common(
+        self,
+    ) -> t.Tuple[t.Optional[t.List], t.Optional[str], t.Optional[exp.Expression]]:
+        privileges = self._parse_csv(self._parse_grant_privilege)
+
+        self._match(TokenType.ON)
+        kind = self._match_set(self.CREATABLES) and self._prev.text.upper()
+
+        # Attempt to parse the securable e.g. MySQL allows names
+        # such as "foo.*", "*.*" which are not easily parseable yet
+        securable = self._try_parse(self._parse_table_parts)
+
+        return privileges, kind, securable
+
+    def _parse_grant(self) -> exp.Grant | exp.Command:
+        start = self._prev
+
+        privileges, kind, securable = self._parse_grant_revoke_common()
+
+        if not securable or not self._match_text_seq("TO"):
+            return self._parse_as_command(start)
+
+        principals = self._parse_csv(self._parse_grant_principal)
+
+        grant_option = self._match_text_seq("WITH", "GRANT", "OPTION")
+
+        if self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(
+            exp.Grant,
+            privileges=privileges,
+            kind=kind,
+            securable=securable,
+            principals=principals,
+            grant_option=grant_option,
+        )
+
+    def _parse_revoke(self) -> exp.Revoke | exp.Command:
+        start = self._prev
+
+        grant_option = self._match_text_seq("GRANT", "OPTION", "FOR")
+
+        privileges, kind, securable = self._parse_grant_revoke_common()
+
+        if not securable or not self._match_text_seq("FROM"):
+            return self._parse_as_command(start)
+
+        principals = self._parse_csv(self._parse_grant_principal)
+
+        cascade = None
+        if self._match_texts(("CASCADE", "RESTRICT")):
+            cascade = self._prev.text.upper()
+
+        if self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(
+            exp.Revoke,
+            privileges=privileges,
+            kind=kind,
+            securable=securable,
+            principals=principals,
+            grant_option=grant_option,
+            cascade=cascade,
+        )
+
+    def _parse_overlay(self) -> exp.Overlay:
+        def _parse_overlay_arg(text: str) -> t.Optional[exp.Expression]:
+            return (
+                self._match(TokenType.COMMA) or self._match_text_seq(text)
+            ) and self._parse_bitwise()
+
+        return self.expression(
+            exp.Overlay,
+            this=self._parse_bitwise(),
+            expression=_parse_overlay_arg("PLACING"),
+            from_=_parse_overlay_arg("FROM"),
+            for_=_parse_overlay_arg("FOR"),
+        )
+
+    def _parse_format_name(self) -> exp.Property:
+        # Note: Although not specified in the docs, Snowflake does accept a string/identifier
+        # for FILE_FORMAT =
+        return self.expression(
+            exp.Property,
+            this=exp.var("FORMAT_NAME"),
+            value=self._parse_string() or self._parse_table_parts(),
+        )
+
+    def _parse_max_min_by(self, expr_type: t.Type[exp.AggFunc]) -> exp.AggFunc:
+        args: t.List[exp.Expression] = []
+
+        if self._match(TokenType.DISTINCT):
+            args.append(
+                self.expression(exp.Distinct, expressions=[self._parse_lambda()])
+            )
+            self._match(TokenType.COMMA)
+
+        args.extend(self._parse_function_args())
+
+        return self.expression(
+            expr_type,
+            this=seq_get(args, 0),
+            expression=seq_get(args, 1),
+            count=seq_get(args, 2),
+        )
+
+    def _identifier_expression(
+        self, token: t.Optional[Token] = None, **kwargs: t.Any
+    ) -> exp.Identifier:
+        return self.expression(exp.Identifier, token=token or self._prev, **kwargs)
+
+    def _build_pipe_cte(
+        self,
+        query: exp.Query,
+        expressions: t.List[exp.Expression],
+        alias_cte: t.Optional[exp.TableAlias] = None,
+    ) -> exp.Select:
+        new_cte: t.Optional[t.Union[str, exp.TableAlias]]
+        if alias_cte:
+            new_cte = alias_cte
+        else:
+            self._pipe_cte_counter += 1
+            new_cte = f"__tmp{self._pipe_cte_counter}"
+
+        with_ = query.args.get("with_")
+        ctes = with_.pop() if with_ else None
+
+        new_select = exp.select(*expressions, copy=False).from_(new_cte, copy=False)
+        if ctes:
+            new_select.set("with_", ctes)
+
+        return new_select.with_(new_cte, as_=query, copy=False)
+
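+    # (illustrative) _build_pipe_cte above lowers each pipe stage into a CTE:
+    #   FROM t |> WHERE a > 0 |> SELECT b
+    # becomes roughly WITH __tmp1 AS (...), __tmp2 AS (...) SELECT * FROM __tmp2.
+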
+    def _parse_pipe_syntax_select(self, query: exp.Select) -> exp.Select:
+        select = self._parse_select(consume_pipe=False)
+        if not select:
+            return query
+
+        return self._build_pipe_cte(
+            query=query.select(*select.expressions, append=False),
+            expressions=[exp.Star()],
+        )
+
+    def _parse_pipe_syntax_limit(self, query: exp.Select) -> exp.Select:
+        limit = self._parse_limit()
+        offset = self._parse_offset()
+        if limit:
+            curr_limit = query.args.get("limit", limit)
+            if curr_limit.expression.to_py() >= limit.expression.to_py():
+                query.limit(limit, copy=False)
+        if offset:
+            curr_offset = query.args.get("offset")
+            curr_offset = curr_offset.expression.to_py() if curr_offset else 0
+            query.offset(
+                exp.Literal.number(curr_offset + offset.expression.to_py()), copy=False
+            )
+
+        return query
+
+    def _parse_pipe_syntax_aggregate_fields(self) -> t.Optional[exp.Expression]:
+        this = self._parse_disjunction()
+        if self._match_text_seq("GROUP", "AND", advance=False):
+            return this
+
+        this = self._parse_alias(this)
+
+        if self._match_set((TokenType.ASC, TokenType.DESC), advance=False):
+            return self._parse_ordered(lambda: this)
+
+        return this
+
+    def _parse_pipe_syntax_aggregate_group_order_by(
+        self, query: exp.Select, group_by_exists: bool = True
+    ) -> exp.Select:
+        expr = self._parse_csv(self._parse_pipe_syntax_aggregate_fields)
+        aggregates_or_groups, orders = [], []
+        for element in expr:
+            if isinstance(element, exp.Ordered):
+                this = element.this
+                if isinstance(this, exp.Alias):
+                    element.set("this", this.args["alias"])
+                orders.append(element)
+            else:
+                this = element
+            aggregates_or_groups.append(this)
+
+        if group_by_exists:
+            query.select(*aggregates_or_groups, copy=False).group_by(
+                *[
+                    projection.args.get("alias", projection)
+                    for projection in aggregates_or_groups
+                ],
+                copy=False,
+            )
+        else:
+            query.select(*aggregates_or_groups, append=False, copy=False)
+
+        if orders:
+            return query.order_by(*orders, append=False, copy=False)
+
+        return query
+
+    def _parse_pipe_syntax_aggregate(self, query: exp.Select) -> exp.Select:
+        self._match_text_seq("AGGREGATE")
+        query = self._parse_pipe_syntax_aggregate_group_order_by(
+            query, group_by_exists=False
+        )
+
+        if self._match(TokenType.GROUP_BY) or (
+            self._match_text_seq("GROUP", "AND") and self._match(TokenType.ORDER_BY)
+        ):
+            query = self._parse_pipe_syntax_aggregate_group_order_by(query)
+
+        return self._build_pipe_cte(query=query, expressions=[exp.Star()])
+
+    def _parse_pipe_syntax_set_operator(
+        self, query: exp.Query
+    ) -> t.Optional[exp.Query]:
+        first_setop = self.parse_set_operation(this=query)
+        if not first_setop:
+            return None
+
+        def _parse_and_unwrap_query() -> t.Optional[exp.Select]:
+            expr = self._parse_paren()
+            return expr.assert_is(exp.Subquery).unnest() if expr else None
+
+        first_setop.this.pop()
+
+        setops = [
+            first_setop.expression.pop().assert_is(exp.Subquery).unnest(),
+            *self._parse_csv(_parse_and_unwrap_query),
+        ]
+
+        query = self._build_pipe_cte(query=query, expressions=[exp.Star()])
+        with_ = query.args.get("with_")
+        ctes = with_.pop() if with_ else None
+
+        if isinstance(first_setop, exp.Union):
+            query = query.union(*setops, copy=False, **first_setop.args)
+        elif isinstance(first_setop, exp.Except):
+            query = query.except_(*setops, copy=False, **first_setop.args)
+        else:
+            query = query.intersect(*setops, copy=False, **first_setop.args)
+
+        query.set("with_", ctes)
+
+        return self._build_pipe_cte(query=query, expressions=[exp.Star()])
+
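+    # (illustrative) a stage like |> UNION ALL (SELECT ...) is handled above by
+    # rewrapping the running query as one operand of the parsed set operation.
+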
+    def _parse_pipe_syntax_join(self, query: exp.Query) -> t.Optional[exp.Query]:
+        join = self._parse_join()
+        if not join:
+            return None
+
+        if isinstance(query, exp.Select):
+            return query.join(join, copy=False)
+
+        return query
+
+    def _parse_pipe_syntax_pivot(self, query: exp.Select) -> exp.Select:
+        pivots = self._parse_pivots()
+        if not pivots:
+            return query
+
+        from_ = query.args.get("from_")
+        if from_:
+            from_.this.set("pivots", pivots)
+
+        return self._build_pipe_cte(query=query, expressions=[exp.Star()])
+
+    def _parse_pipe_syntax_extend(self, query: exp.Select) -> exp.Select:
+        self._match_text_seq("EXTEND")
+        query.select(
+            *[exp.Star(), *self._parse_expressions()], append=False, copy=False
+        )
+        return self._build_pipe_cte(query=query, expressions=[exp.Star()])
+
+    def _parse_pipe_syntax_tablesample(self, query: exp.Select) -> exp.Select:
+        sample = self._parse_table_sample()
+
+        with_ = query.args.get("with_")
+        if with_:
+            with_.expressions[-1].this.set("sample", sample)
+        else:
+            query.set("sample", sample)
+
+        return query
+
+    def _parse_pipe_syntax_query(self, query: exp.Query) -> t.Optional[exp.Query]:
+        if isinstance(query, exp.Subquery):
+            query = exp.select("*").from_(query, copy=False)
+
+        if not query.args.get("from_"):
+            query = exp.select("*").from_(query.subquery(copy=False), copy=False)
+
+        while self._match(TokenType.PIPE_GT):
+            start = self._curr
+            parser = self.PIPE_SYNTAX_TRANSFORM_PARSERS.get(self._curr.text.upper())
+            if not parser:
+                # The set operators (UNION, etc) and the JOIN operator have a few common starting
+                # keywords, making it tricky to disambiguate them without lookahead. The approach
+                # here is to try and parse a set operation and if that fails, then try to parse a
+                # join operator. If that fails as well, then the operator is not supported.
+                parsed_query = self._parse_pipe_syntax_set_operator(query)
+                parsed_query = parsed_query or self._parse_pipe_syntax_join(query)
+                if not parsed_query:
+                    self._retreat(start)
+                    self.raise_error(
+                        f"Unsupported pipe syntax operator: '{start.text.upper()}'."
+                    )
+                    break
+                query = parsed_query
+            else:
+                query = parser(self, query)
+
+        return query
+
+    def _parse_declareitem(self) -> t.Optional[exp.DeclareItem]:
+        vars = self._parse_csv(self._parse_id_var)
+        if not vars:
+            return None
+
+        return self.expression(
+            exp.DeclareItem,
+            this=vars,
+            kind=self._parse_types(),
+            default=self._match(TokenType.DEFAULT) and self._parse_bitwise(),
+        )
+
+    def _parse_declare(self) -> exp.Declare | exp.Command:
+        start = self._prev
+        expressions = self._try_parse(lambda: self._parse_csv(self._parse_declareitem))
+
+        if not expressions or self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(exp.Declare, expressions=expressions)
+
+    def build_cast(self, strict: bool, **kwargs) -> exp.Cast:
+        exp_class = exp.Cast if strict else exp.TryCast
+
+        if exp_class == exp.TryCast:
+            kwargs["requires_string"] = self.dialect.TRY_CAST_REQUIRES_STRING
+
+        return self.expression(exp_class, **kwargs)
+
+    def _parse_json_value(self) -> exp.JSONValue:
+        this = self._parse_bitwise()
+        self._match(TokenType.COMMA)
+        path = self._parse_bitwise()
+
+        returning = self._match(TokenType.RETURNING) and self._parse_type()
+
+        return self.expression(
+            exp.JSONValue,
+            this=this,
+            path=self.dialect.to_json_path(path),
+            returning=returning,
+            on_condition=self._parse_on_condition(),
+        )
+
+    def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+        def concat_exprs(
+            node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+        ) -> exp.Expression:
+            if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+                concat_exprs = [
+                    self.expression(
+                        exp.Concat,
+                        expressions=node.expressions,
+                        safe=True,
+                        coalesce=self.dialect.CONCAT_COALESCE,
+                    )
+                ]
+                node.set("expressions", concat_exprs)
+                return node
+            if len(exprs) == 1:
+                return exprs[0]
+            return self.expression(
+                exp.Concat,
+                expressions=args,
+                safe=True,
+                coalesce=self.dialect.CONCAT_COALESCE,
+            )
+
+        args = self._parse_csv(self._parse_lambda)
+
+        if args:
+            order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+            if order:
+                # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+                # remove 'expr' from exp.Order and add it back to args
+                args[-1] = order.this
+                order.set("this", concat_exprs(order.this, args))
+
+            this = order or concat_exprs(args[0], args)
+        else:
+            this = None
+
+        separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+        return self.expression(exp.GroupConcat, this=this, separator=separator)
+
+    def _parse_initcap(self) -> exp.Initcap:
+        expr = exp.Initcap.from_arg_list(self._parse_function_args())
+
+        # attach dialect's default delimiters
+        if expr.args.get("expression") is None:
+            expr.set(
+                "expression",
+                exp.Literal.string(self.dialect.INITCAP_DEFAULT_DELIMITER_CHARS),
+            )
+
+        return expr
+
+    def _parse_operator(
+        self, this: t.Optional[exp.Expression]
+    ) -> t.Optional[exp.Expression]:
+        while True:
+            if not self._match(TokenType.L_PAREN):
+                break
+
+            op = ""
+            while self._curr and not self._match(TokenType.R_PAREN):
+                op += self._curr.text
+                self._advance()
+
+            this = self.expression(
+                exp.Operator,
+                comments=self._prev_comments,
+                this=this,
+                operator=op,
+                expression=self._parse_bitwise(),
+            )
+
+            if not self._match(TokenType.OPERATOR):
+                break
+
+        return this
diff --git a/third_party/bigframes_vendored/sqlglot/planner.py b/third_party/bigframes_vendored/sqlglot/planner.py
new file mode 100644
index 0000000000..58373fca3c
--- /dev/null
+++ b/third_party/bigframes_vendored/sqlglot/planner.py
@@ -0,0 +1,471 @@
+from __future__ import annotations
+
+import math
+import typing as t
+
+from bigframes_vendored.sqlglot import alias, exp
+from bigframes_vendored.sqlglot.helper import name_sequence
+from bigframes_vendored.sqlglot.optimizer.eliminate_joins import join_condition
+
+
+class Plan:
+    def __init__(self, expression: exp.Expression) -> None:
+        self.expression = expression.copy()
+        self.root = Step.from_expression(self.expression)
+        self._dag: t.Dict[Step, t.Set[Step]] = {}
+
+    @property
+    def dag(self) -> t.Dict[Step, t.Set[Step]]:
+        if not self._dag:
+            dag: t.Dict[Step, t.Set[Step]] = {}
+            nodes = {self.root}
+
+            while nodes:
+                node = nodes.pop()
+                dag[node] = set()
+
+                for dep in node.dependencies:
+                    dag[node].add(dep)
+                    nodes.add(dep)
+
+            self._dag = dag
+
+        return self._dag
+
+    @property
+    def leaves(self) -> t.Iterator[Step]:
+        return (node for node, deps in self.dag.items() if not deps)
+
+    def __repr__(self) -> str:
+        return f"Plan\n----\n{repr(self.root)}"
+
+
+class Step:
+    @classmethod
+    def from_expression(
+        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
+        """
+        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
+        Note: the expression's tables and subqueries must be aliased for this method to work. For
+        example, given the following expression:
+
+        SELECT
+          x.a,
+          SUM(x.b)
+        FROM x AS x
+        JOIN y AS y
+          ON x.a = y.a
+        GROUP BY x.a
+
+        the following DAG is produced (the expression IDs might differ per execution):
+
+        - Aggregate: x (4347984624)
+            Context:
+              Aggregations:
+                - SUM(x.b)
+              Group:
+                - x.a
+            Projections:
+              - x.a
+              - "x".""
+            Dependencies:
+            - Join: x (4347985296)
+              Context:
+                y:
+                On: x.a = y.a
+              Projections:
+              Dependencies:
+              - Scan: x (4347983136)
+                Context:
+                  Source: x AS x
+                Projections:
+              - Scan: y (4343416624)
+                Context:
+                  Source: y AS y
+                Projections:
+
+        Args:
+            expression: the expression to build the DAG from.
+            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
+
+        Returns:
+            A Step DAG corresponding to `expression`.
+        """
+        ctes = ctes or {}
+        expression = expression.unnest()
+        with_ = expression.args.get("with_")
+
+        # CTEs break the mold of scope and introduce themselves to all in the context.
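+        # (illustrative) for WITH c AS (SELECT ...) SELECT * FROM c, the Step
+        # built for "c" is registered here so the later Scan of "c" depends on it.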
+        if with_:
+            ctes = ctes.copy()
+            for cte in with_.expressions:
+                step = Step.from_expression(cte.this, ctes)
+                step.name = cte.alias
+                ctes[step.name] = step  # type: ignore
+
+        from_ = expression.args.get("from_")
+
+        if isinstance(expression, exp.Select) and from_:
+            step = Scan.from_expression(from_.this, ctes)
+        elif isinstance(expression, exp.SetOperation):
+            step = SetOperation.from_expression(expression, ctes)
+        else:
+            step = Scan()
+
+        joins = expression.args.get("joins")
+
+        if joins:
+            join = Join.from_joins(joins, ctes)
+            join.name = step.name
+            join.source_name = step.name
+            join.add_dependency(step)
+            step = join
+
+        projections = []  # final selects in this chain of steps representing a select
+        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
+        aggregations = {}
+        next_operand_name = name_sequence("_a_")
+
+        def extract_agg_operands(expression):
+            agg_funcs = tuple(expression.find_all(exp.AggFunc))
+            if agg_funcs:
+                aggregations[expression] = None
+
+            for agg in agg_funcs:
+                for operand in agg.unnest_operands():
+                    if isinstance(operand, exp.Column):
+                        continue
+                    if operand not in operands:
+                        operands[operand] = next_operand_name()
+
+                    operand.replace(exp.column(operands[operand], quoted=True))
+
+            return bool(agg_funcs)
+
+        def set_ops_and_aggs(step):
+            step.operands = tuple(
+                alias(operand, alias_) for operand, alias_ in operands.items()
+            )
+            step.aggregations = list(aggregations)
+
+        for e in expression.expressions:
+            if e.find(exp.AggFunc):
+                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
+                extract_agg_operands(e)
+            else:
+                projections.append(e)
+
+        where = expression.args.get("where")
+
+        if where:
+            step.condition = where.this
+
+        group = expression.args.get("group")
+
+        if group or aggregations:
+            aggregate = Aggregate()
+            aggregate.source = step.name
+            aggregate.name = step.name
+
+            having = expression.args.get("having")
+
+            if having:
+                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
+                    aggregate.condition = exp.column("_h", step.name, quoted=True)
+                else:
+                    aggregate.condition = having.this
+
+            set_ops_and_aggs(aggregate)
+
+            # give aggregates names and replace projections with references to them
+            aggregate.group = {
+                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
+            }
+
+            intermediate: t.Dict[str | exp.Expression, str] = {}
+            for k, v in aggregate.group.items():
+                intermediate[v] = k
+                if isinstance(v, exp.Column):
+                    intermediate[v.name] = k
+
+            for projection in projections:
+                for node in projection.walk():
+                    name = intermediate.get(node)
+                    if name:
+                        node.replace(exp.column(name, step.name))
+
+            if aggregate.condition:
+                for node in aggregate.condition.walk():
+                    name = intermediate.get(node) or intermediate.get(node.name)
+                    if name:
+                        node.replace(exp.column(name, step.name))
+
+            aggregate.add_dependency(step)
+            step = aggregate
+        else:
+            aggregate = None
+
+        order = expression.args.get("order")
+
+        if order:
+            if aggregate and isinstance(step, Aggregate):
+                for i, ordered in enumerate(order.expressions):
+                    if extract_agg_operands(
+                        exp.alias_(ordered.this, f"_o_{i}", quoted=True)
+                    ):
+                        ordered.this.replace(
+                            exp.column(f"_o_{i}", step.name, quoted=True)
+                        )
+
+                set_ops_and_aggs(aggregate)
+
+            sort = Sort()
+            sort.name = step.name
+            sort.key = order.expressions
+            sort.add_dependency(step)
+            step = sort
+
+        step.projections = projections
+
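+        # (illustrative) SELECT DISTINCT a, b is lowered below into an extra
+        # Aggregate step that groups on every projection.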
step.name + distinct.name = step.name + distinct.group = { + e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name) + for e in projections or expression.expressions + } + distinct.add_dependency(step) + step = distinct + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + + return step + + def __init__(self) -> None: + self.name: t.Optional[str] = None + self.dependencies: t.Set[Step] = set() + self.dependents: t.Set[Step] = set() + self.projections: t.Sequence[exp.Expression] = [] + self.limit: float = math.inf + self.condition: t.Optional[exp.Expression] = None + + def add_dependency(self, dependency: Step) -> None: + self.dependencies.add(dependency) + dependency.dependents.add(self) + + def __repr__(self) -> str: + return self.to_s() + + def to_s(self, level: int = 0) -> str: + indent = " " * level + nested = f"{indent} " + + context = self._to_s(f"{nested} ") + + if context: + context = [f"{nested}Context:"] + context + + lines = [ + f"{indent}- {self.id}", + *context, + f"{nested}Projections:", + ] + + for expression in self.projections: + lines.append(f"{nested} - {expression.sql()}") + + if self.condition: + lines.append(f"{nested}Condition: {self.condition.sql()}") + + if self.limit is not math.inf: + lines.append(f"{nested}Limit: {self.limit}") + + if self.dependencies: + lines.append(f"{nested}Dependencies:") + for dependency in self.dependencies: + lines.append(" " + dependency.to_s(level + 1)) + + return "\n".join(lines) + + @property + def type_name(self) -> str: + return self.__class__.__name__ + + @property + def id(self) -> str: + name = self.name + name = f" {name}" if name else "" + return f"{self.type_name}:{name} ({id(self)})" + + def _to_s(self, _indent: str) -> t.List[str]: + return [] + + +class Scan(Step): + @classmethod + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: + table = expression + alias_ = expression.alias_or_name + + if isinstance(expression, exp.Subquery): + table = expression.this + step = Step.from_expression(table, ctes) + step.name = alias_ + return step + + step = Scan() + step.name = alias_ + step.source = expression + if ctes and table.name in ctes: + step.add_dependency(ctes[table.name]) + + return step + + def __init__(self) -> None: + super().__init__() + self.source: t.Optional[exp.Expression] = None + + def _to_s(self, indent: str) -> t.List[str]: + return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore + + +class Join(Step): + @classmethod + def from_joins( + cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Join: + step = Join() + + for join in joins: + source_key, join_key, condition = join_condition(join) + step.joins[join.alias_or_name] = { + "side": join.side, # type: ignore + "join_key": join_key, + "source_key": source_key, + "condition": condition, + } + + step.add_dependency(Scan.from_expression(join.this, ctes)) + + return step + + def __init__(self) -> None: + super().__init__() + self.source_name: t.Optional[str] = None + self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {} + + def _to_s(self, indent: str) -> t.List[str]: + lines = [f"{indent}Source: {self.source_name or self.name}"] + for name, join in self.joins.items(): + lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") + join_key = ", ".join( + str(key) for key in t.cast(list, join.get("join_key") or []) + ) + if join_key: + lines.append(f"{indent}Key: 
{join_key}") + if join.get("condition"): + lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore + return lines + + +class Aggregate(Step): + def __init__(self) -> None: + super().__init__() + self.aggregations: t.List[exp.Expression] = [] + self.operands: t.Tuple[exp.Expression, ...] = () + self.group: t.Dict[str, exp.Expression] = {} + self.source: t.Optional[str] = None + + def _to_s(self, indent: str) -> t.List[str]: + lines = [f"{indent}Aggregations:"] + + for expression in self.aggregations: + lines.append(f"{indent} - {expression.sql()}") + + if self.group: + lines.append(f"{indent}Group:") + for expression in self.group.values(): + lines.append(f"{indent} - {expression.sql()}") + if self.condition: + lines.append(f"{indent}Having:") + lines.append(f"{indent} - {self.condition.sql()}") + if self.operands: + lines.append(f"{indent}Operands:") + for expression in self.operands: + lines.append(f"{indent} - {expression.sql()}") + + return lines + + +class Sort(Step): + def __init__(self) -> None: + super().__init__() + self.key = None + + def _to_s(self, indent: str) -> t.List[str]: + lines = [f"{indent}Key:"] + + for expression in self.key: # type: ignore + lines.append(f"{indent} - {expression.sql()}") + + return lines + + +class SetOperation(Step): + def __init__( + self, + op: t.Type[exp.Expression], + left: str | None, + right: str | None, + distinct: bool = False, + ) -> None: + super().__init__() + self.op = op + self.left = left + self.right = right + self.distinct = distinct + + @classmethod + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> SetOperation: + assert isinstance(expression, exp.SetOperation) + + left = Step.from_expression(expression.left, ctes) + # SELECT 1 UNION SELECT 2 <-- these subqueries don't have names + left.name = left.name or "left" + right = Step.from_expression(expression.right, ctes) + right.name = right.name or "right" + step = cls( + op=expression.__class__, + left=left.name, + right=right.name, + distinct=bool(expression.args.get("distinct")), + ) + + step.add_dependency(left) + step.add_dependency(right) + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + + return step + + def _to_s(self, indent: str) -> t.List[str]: + lines = [] + if self.distinct: + lines.append(f"{indent}Distinct: {self.distinct}") + return lines + + @property + def type_name(self) -> str: + return self.op.__name__ diff --git a/third_party/bigframes_vendored/sqlglot/py.typed b/third_party/bigframes_vendored/sqlglot/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/bigframes_vendored/sqlglot/schema.py b/third_party/bigframes_vendored/sqlglot/schema.py new file mode 100644 index 0000000000..0f291735c8 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/schema.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +import abc +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.dialects.dialect import Dialect +from bigframes_vendored.sqlglot.errors import SchemaError +from bigframes_vendored.sqlglot.helper import dict_depth, first +from bigframes_vendored.sqlglot.trie import in_trie, new_trie, TrieResult + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + ColumnMapping = t.Union[t.Dict, str, t.List] + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @property + def dialect(self) -> 
t.Optional[Dialect]: + """ + Returns None by default. Subclasses that require dialect-specific + behavior should override this property. + """ + return None + + @abc.abstractmethod + def add_table( + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, + ) -> None: + """ + Register or update a table. Some implementing classes may require column information to also be provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. + + Args: + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. + """ + + @abc.abstractmethod + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> t.Sequence[str]: + """ + Get the column names for a table. + + Args: + table: the `Table` expression instance. + only_visible: whether to include invisible columns. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + + Returns: + The sequence of column names. + """ + + @abc.abstractmethod + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> exp.DataType: + """ + Get the `sqlglot.exp.DataType` type of a column in the schema. + + Args: + table: the source table. + column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + + Returns: + The resulting column type. + """ + + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + """ + Returns whether `column` appears in `table`'s schema. + + Args: + table: the source table. + column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + + Returns: + True if the column appears in the schema, False otherwise. + """ + name = column if isinstance(column, str) else column.name + return name in self.column_names(table, dialect=dialect, normalize=normalize) + + @property + @abc.abstractmethod + def supported_table_args(self) -> t.Tuple[str, ...]: + """ + Table arguments this schema supports, e.g. `("this", "db", "catalog")` + """ + + @property + def empty(self) -> bool: + """Returns whether the schema is empty.""" + return True + + +class AbstractMappingSchema: + def __init__( + self, + mapping: t.Optional[t.Dict] = None, + ) -> None: + self.mapping = mapping or {} + self.mapping_trie = new_trie( + tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) + ) + self._supported_table_args: t.Tuple[str, ...]
= tuple() + + @property + def empty(self) -> bool: + return not self.mapping + + def depth(self) -> int: + return dict_depth(self.mapping) + + @property + def supported_table_args(self) -> t.Tuple[str, ...]: + if not self._supported_table_args and self.mapping: + depth = self.depth() + + if not depth: # None + self._supported_table_args = tuple() + elif 1 <= depth <= 3: + self._supported_table_args = exp.TABLE_PARTS[:depth] + else: + raise SchemaError(f"Invalid mapping shape. Depth: {depth}") + + return self._supported_table_args + + def table_parts(self, table: exp.Table) -> t.List[str]: + return [part.name for part in reversed(table.parts)] + + def find( + self, + table: exp.Table, + raise_on_missing: bool = True, + ensure_data_types: bool = False, + ) -> t.Optional[t.Any]: + """ + Returns the schema of a given table. + + Args: + table: the target table. + raise_on_missing: whether to raise in case the schema is not found. + ensure_data_types: whether to convert `str` types to their `DataType` equivalents. + + Returns: + The schema of the target table. + """ + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.mapping_trie, parts) + + if value == TrieResult.FAILED: + return None + + if value == TrieResult.PREFIX: + possibilities = flatten_schema(trie) + + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous mapping for {table}: {message}.") + return None + + return self.nested_get(parts, raise_on_missing=raise_on_missing) + + def nested_get( + self, + parts: t.Sequence[str], + d: t.Optional[t.Dict] = None, + raise_on_missing=True, + ) -> t.Optional[t.Any]: + return nested_get( + d or self.mapping, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) + + +class MappingSchema(AbstractMappingSchema, Schema): + """ + Schema based on a nested mapping. + + Args: + schema: Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + 4. None - Tables will be added later + visible: Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)} + 2. {db: {table: set(*cols)}} + 3. {catalog: {db: {table: set(*cols)}}} + dialect: The dialect to be used for custom type mappings & parsing string arguments. + normalize: Whether to normalize identifier names according to the given dialect or not.
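+ + Example (illustrative): + >>> schema = MappingSchema({"t": {"a": "INT", "b": "TEXT"}}) + >>> schema.column_names("t") + ['a', 'b']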
+ """ + + def __init__( + self, + schema: t.Optional[t.Dict] = None, + visible: t.Optional[t.Dict] = None, + dialect: DialectType = None, + normalize: bool = True, + ) -> None: + self.visible = {} if visible is None else visible + self.normalize = normalize + self._dialect = Dialect.get_or_raise(dialect) + self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + self._depth = 0 + schema = {} if schema is None else schema + + super().__init__(self._normalize(schema) if self.normalize else schema) + + @property + def dialect(self) -> Dialect: + """Returns the dialect for this mapping schema.""" + return self._dialect + + @classmethod + def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: + return MappingSchema( + schema=mapping_schema.mapping, + visible=mapping_schema.visible, + dialect=mapping_schema.dialect, + normalize=mapping_schema.normalize, + ) + + def find( + self, + table: exp.Table, + raise_on_missing: bool = True, + ensure_data_types: bool = False, + ) -> t.Optional[t.Any]: + schema = super().find( + table, + raise_on_missing=raise_on_missing, + ensure_data_types=ensure_data_types, + ) + if ensure_data_types and isinstance(schema, dict): + schema = { + col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype + for col, dtype in schema.items() + } + + return schema + + def copy(self, **kwargs) -> MappingSchema: + return MappingSchema( + **{ # type: ignore + "schema": self.mapping.copy(), + "visible": self.visible.copy(), + "dialect": self.dialect, + "normalize": self.normalize, + **kwargs, + } + ) + + def add_table( + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, + ) -> None: + """ + Register or update a table. Updates are only performed if a new column mapping is provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. + + Args: + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. + """ + normalized_table = self._normalize_table( + table, dialect=dialect, normalize=normalize + ) + + if ( + match_depth + and not self.empty + and len(normalized_table.parts) != self.depth() + ): + raise SchemaError( + f"Table {normalized_table.sql(dialect=self.dialect)} must match the " + f"schema's nesting level: {self.depth()}." 
+ ) + + normalized_column_mapping = { + self._normalize_name(key, dialect=dialect, normalize=normalize): value + for key, value in ensure_column_mapping(column_mapping).items() + } + + schema = self.find(normalized_table, raise_on_missing=False) + if schema and not normalized_column_mapping: + return + + parts = self.table_parts(normalized_table) + + nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) + new_trie([parts], self.mapping_trie) + + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> t.List[str]: + normalized_table = self._normalize_table( + table, dialect=dialect, normalize=normalize + ) + + schema = self.find(normalized_table) + if schema is None: + return [] + + if not only_visible or not self.visible: + return list(schema) + + visible = ( + self.nested_get(self.table_parts(normalized_table), self.visible) or [] + ) + return [col for col in schema if col in visible] + + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> exp.DataType: + normalized_table = self._normalize_table( + table, dialect=dialect, normalize=normalize + ) + + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, + dialect=dialect, + normalize=normalize, + ) + + table_schema = self.find(normalized_table, raise_on_missing=False) + if table_schema: + column_type = table_schema.get(normalized_column_name) + + if isinstance(column_type, exp.DataType): + return column_type + elif isinstance(column_type, str): + return self._to_data_type(column_type, dialect=dialect) + + return exp.DataType.build("unknown") + + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + normalized_table = self._normalize_table( + table, dialect=dialect, normalize=normalize + ) + + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, + dialect=dialect, + normalize=normalize, + ) + + table_schema = self.find(normalized_table, raise_on_missing=False) + return normalized_column_name in table_schema if table_schema else False + + def _normalize(self, schema: t.Dict) -> t.Dict: + """ + Normalizes all identifiers in the schema. + + Args: + schema: the schema to normalize. + + Returns: + The normalized schema mapping. + """ + normalized_mapping: t.Dict = {} + flattened_schema = flatten_schema(schema) + error_msg = "Table {} must match the schema's nesting level: {}." 
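+ # Each `keys` entry is one flattened path (e.g. catalog, db, table); its leaf + # must be a non-empty column-to-type mapping that nests no deeper than the + # schema itself, otherwise a SchemaError is raised below.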
+ + for keys in flattened_schema: + columns = nested_get(schema, *zip(keys, keys)) + + if not isinstance(columns, dict): + raise SchemaError( + error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])) + ) + if not columns: + raise SchemaError( + f"Table {'.'.join(keys[:-1])} must have at least one column" + ) + if isinstance(first(columns.values()), dict): + raise SchemaError( + error_msg.format( + ".".join(keys + flatten_schema(columns)[0]), + len(flattened_schema[0]), + ), + ) + + normalized_keys = [self._normalize_name(key, is_table=True) for key in keys] + for column_name, column_type in columns.items(): + nested_set( + normalized_mapping, + normalized_keys + [self._normalize_name(column_name)], + column_type, + ) + + return normalized_mapping + + def _normalize_table( + self, + table: exp.Table | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> exp.Table: + dialect = dialect or self.dialect + normalize = self.normalize if normalize is None else normalize + + normalized_table = exp.maybe_parse( + table, into=exp.Table, dialect=dialect, copy=normalize + ) + + if normalize: + for part in normalized_table.parts: + if isinstance(part, exp.Identifier): + part.replace( + normalize_name( + part, dialect=dialect, is_table=True, normalize=normalize + ) + ) + + return normalized_table + + def _normalize_name( + self, + name: str | exp.Identifier, + dialect: DialectType = None, + is_table: bool = False, + normalize: t.Optional[bool] = None, + ) -> str: + return normalize_name( + name, + dialect=dialect or self.dialect, + is_table=is_table, + normalize=self.normalize if normalize is None else normalize, + ).name + + def depth(self) -> int: + if not self.empty and not self._depth: + # The columns themselves are a mapping, but we don't want to include those + self._depth = super().depth() - 1 + return self._depth + + def _to_data_type( + self, schema_type: str, dialect: DialectType = None + ) -> exp.DataType: + """ + Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. + + Args: + schema_type: the type we want to convert. + dialect: the SQL dialect that will be used to parse `schema_type`, if needed. + + Returns: + The resulting expression type. 
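+ For example, "TEXT" is built into `exp.DataType.build("TEXT")` and cached + in `_type_mapping_cache`, so repeated lookups of the same type string + reuse the parsed `DataType`.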
+ """ + if schema_type not in self._type_mapping_cache: + dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect + udt = dialect.SUPPORTS_USER_DEFINED_TYPES + + try: + expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) + self._type_mapping_cache[schema_type] = expression + except AttributeError: + in_dialect = f" in dialect {dialect}" if dialect else "" + raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") + + return self._type_mapping_cache[schema_type] + + +def normalize_name( + identifier: str | exp.Identifier, + dialect: DialectType = None, + is_table: bool = False, + normalize: t.Optional[bool] = True, +) -> exp.Identifier: + if isinstance(identifier, str): + identifier = exp.parse_identifier(identifier, dialect=dialect) + + if not normalize: + return identifier + + # this is used for normalize_identifier, bigquery has special rules pertaining tables + identifier.meta["is_table"] = is_table + return Dialect.get_or_raise(dialect).normalize_identifier(identifier) + + +def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema, **kwargs) + + +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: + if mapping is None: + return {} + elif isinstance(mapping, dict): + return mapping + elif isinstance(mapping, str): + col_name_type_strs = [x.strip() for x in mapping.split(",")] + return { + name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() + for name_type_str in col_name_type_strs + } + elif isinstance(mapping, list): + return {x.strip(): None for x in mapping} + + raise ValueError(f"Invalid mapping provided: {type(mapping)}") + + +def flatten_schema( + schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None +) -> t.List[t.List[str]]: + tables = [] + keys = keys or [] + depth = dict_depth(schema) - 1 if depth is None else depth + + for k, v in schema.items(): + if depth == 1 or not isinstance(v, dict): + tables.append(keys + [k]) + elif depth >= 2: + tables.extend(flatten_schema(v, depth - 1, keys + [k])) + + return tables + + +def nested_get( + d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True +) -> t.Optional[t.Any]: + """ + Get a value for a nested dictionary. + + Args: + d: the dictionary to search. + *path: tuples of (name, key), where: + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + + Returns: + The value or None if it doesn't exist. + """ + for name, key in path: + d = d.get(key) # type: ignore + if d is None: + if raise_on_missing: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}: {key}") + return None + + return d + + +def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: + """ + In-place set a value for a nested dictionary + + Example: + >>> nested_set({}, ["top_key", "second_key"], "value") + {'top_key': {'second_key': 'value'}} + + >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") + {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} + + Args: + d: dictionary to update. + keys: the keys that makeup the path to `value`. + value: the value to set in the dictionary for the given key path. + + Returns: + The (possibly) updated dictionary. 
+ """ + if not keys: + return d + + if len(keys) == 1: + d[keys[0]] = value + return d + + subd = d + for key in keys[:-1]: + if key not in subd: + subd = subd.setdefault(key, {}) + else: + subd = subd[key] + + subd[keys[-1]] = value + return d diff --git a/third_party/bigframes_vendored/sqlglot/serde.py b/third_party/bigframes_vendored/sqlglot/serde.py new file mode 100644 index 0000000000..43d71621d5 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/serde.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp + +INDEX = "i" +ARG_KEY = "k" +IS_ARRAY = "a" +CLASS = "c" +TYPE = "t" +COMMENTS = "o" +META = "m" +VALUE = "v" +DATA_TYPE = "DataType.Type" + + +def dump(expression: exp.Expression) -> t.List[t.Dict[str, t.Any]]: + """ + Dump an Expression into a JSON serializable List. + """ + i = 0 + payloads = [] + stack: t.List[t.Tuple[t.Any, t.Optional[int], t.Optional[str], bool]] = [ + (expression, None, None, False) + ] + + while stack: + node, index, arg_key, is_array = stack.pop() + + payload: t.Dict[str, t.Any] = {} + + if index is not None: + payload[INDEX] = index + if arg_key is not None: + payload[ARG_KEY] = arg_key + if is_array: + payload[IS_ARRAY] = is_array + + payloads.append(payload) + + if hasattr(node, "parent"): + klass = node.__class__.__qualname__ + + if node.__class__.__module__ != exp.__name__: + klass = f"{node.__module__}.{klass}" + + payload[CLASS] = klass + + if node.type: + payload[TYPE] = dump(node.type) + if node.comments: + payload[COMMENTS] = node.comments + if node._meta is not None: + payload[META] = node._meta + if node.args: + for k, vs in reversed(node.args.items()): + if type(vs) is list: + for v in reversed(vs): + stack.append((v, i, k, True)) + elif vs is not None: + stack.append((vs, i, k, False)) + elif type(node) is exp.DataType.Type: + payload[CLASS] = DATA_TYPE + payload[VALUE] = node.value + else: + payload[VALUE] = node + + i += 1 + + return payloads + + +@t.overload +def load(payloads: None) -> None: + ... + + +@t.overload +def load(payloads: t.List[t.Dict[str, t.Any]]) -> exp.Expression: + ... + + +def load(payloads): + """ + Load a list of dicts generated by dump into an Expression. + """ + + if not payloads: + return None + + payload, *tail = payloads + root = _load(payload) + nodes = [root] + for payload in tail: + node = _load(payload) + nodes.append(node) + parent = nodes[payload[INDEX]] + arg_key = payload[ARG_KEY] + + if payload.get(IS_ARRAY): + parent.append(arg_key, node) + else: + parent.set(arg_key, node) + + return root + + +def _load(payload: t.Dict[str, t.Any]) -> exp.Expression | exp.DataType.Type: + class_name = payload.get(CLASS) + + if not class_name: + return payload[VALUE] + if class_name == DATA_TYPE: + return exp.DataType.Type(payload[VALUE]) + + if "." 
in class_name: + module_path, class_name = class_name.rsplit(".", maxsplit=1) + module = __import__(module_path, fromlist=[class_name]) + else: + module = exp + + expression = getattr(module, class_name)() + expression.type = load(payload.get(TYPE)) + expression.comments = payload.get(COMMENTS) + expression._meta = payload.get(META) + return expression diff --git a/third_party/bigframes_vendored/sqlglot/time.py b/third_party/bigframes_vendored/sqlglot/time.py new file mode 100644 index 0000000000..ee3b80bfe1 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/time.py @@ -0,0 +1,687 @@ +import datetime +import typing as t + +# The generic time format is based on python time.strftime. +# https://docs.python.org/3/library/time.html#time.strftime +from bigframes_vendored.sqlglot.trie import in_trie, new_trie, TrieResult + + +def format_time( + string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None +) -> t.Optional[str]: + """ + Converts a time string given a mapping. + + Examples: + >>> format_time("%Y", {"%Y": "YYYY"}) + 'YYYY' + + Args: + mapping: dictionary of time format to target time format. + trie: optional trie, can be passed in for performance. + + Returns: + The converted time string. + """ + if not string: + return None + + start = 0 + end = 1 + size = len(string) + trie = trie or new_trie(mapping) + current = trie + chunks = [] + sym = None + + while end <= size: + chars = string[start:end] + result, current = in_trie(current, chars[-1]) + + if result == TrieResult.FAILED: + if sym: + end -= 1 + chars = sym + sym = None + else: + chars = chars[0] + end = start + 1 + + start += len(chars) + chunks.append(chars) + current = trie + elif result == TrieResult.EXISTS: + sym = chars + + end += 1 + + if result != TrieResult.FAILED and end > size: + chunks.append(chars) + + return "".join(mapping.get(chars, chars) for chars in chunks) + + +TIMEZONES = { + tz.lower() + for tz in ( + "Africa/Abidjan", + "Africa/Accra", + "Africa/Addis_Ababa", + "Africa/Algiers", + "Africa/Asmara", + "Africa/Asmera", + "Africa/Bamako", + "Africa/Bangui", + "Africa/Banjul", + "Africa/Bissau", + "Africa/Blantyre", + "Africa/Brazzaville", + "Africa/Bujumbura", + "Africa/Cairo", + "Africa/Casablanca", + "Africa/Ceuta", + "Africa/Conakry", + "Africa/Dakar", + "Africa/Dar_es_Salaam", + "Africa/Djibouti", + "Africa/Douala", + "Africa/El_Aaiun", + "Africa/Freetown", + "Africa/Gaborone", + "Africa/Harare", + "Africa/Johannesburg", + "Africa/Juba", + "Africa/Kampala", + "Africa/Khartoum", + "Africa/Kigali", + "Africa/Kinshasa", + "Africa/Lagos", + "Africa/Libreville", + "Africa/Lome", + "Africa/Luanda", + "Africa/Lubumbashi", + "Africa/Lusaka", + "Africa/Malabo", + "Africa/Maputo", + "Africa/Maseru", + "Africa/Mbabane", + "Africa/Mogadishu", + "Africa/Monrovia", + "Africa/Nairobi", + "Africa/Ndjamena", + "Africa/Niamey", + "Africa/Nouakchott", + "Africa/Ouagadougou", + "Africa/Porto-Novo", + "Africa/Sao_Tome", + "Africa/Timbuktu", + "Africa/Tripoli", + "Africa/Tunis", + "Africa/Windhoek", + "America/Adak", + "America/Anchorage", + "America/Anguilla", + "America/Antigua", + "America/Araguaina", + "America/Argentina/Buenos_Aires", + "America/Argentina/Catamarca", + "America/Argentina/ComodRivadavia", + "America/Argentina/Cordoba", + "America/Argentina/Jujuy", + "America/Argentina/La_Rioja", + "America/Argentina/Mendoza", + "America/Argentina/Rio_Gallegos", + "America/Argentina/Salta", + "America/Argentina/San_Juan", + "America/Argentina/San_Luis", + "America/Argentina/Tucuman", + 
"America/Argentina/Ushuaia", + "America/Aruba", + "America/Asuncion", + "America/Atikokan", + "America/Atka", + "America/Bahia", + "America/Bahia_Banderas", + "America/Barbados", + "America/Belem", + "America/Belize", + "America/Blanc-Sablon", + "America/Boa_Vista", + "America/Bogota", + "America/Boise", + "America/Buenos_Aires", + "America/Cambridge_Bay", + "America/Campo_Grande", + "America/Cancun", + "America/Caracas", + "America/Catamarca", + "America/Cayenne", + "America/Cayman", + "America/Chicago", + "America/Chihuahua", + "America/Ciudad_Juarez", + "America/Coral_Harbour", + "America/Cordoba", + "America/Costa_Rica", + "America/Creston", + "America/Cuiaba", + "America/Curacao", + "America/Danmarkshavn", + "America/Dawson", + "America/Dawson_Creek", + "America/Denver", + "America/Detroit", + "America/Dominica", + "America/Edmonton", + "America/Eirunepe", + "America/El_Salvador", + "America/Ensenada", + "America/Fort_Nelson", + "America/Fort_Wayne", + "America/Fortaleza", + "America/Glace_Bay", + "America/Godthab", + "America/Goose_Bay", + "America/Grand_Turk", + "America/Grenada", + "America/Guadeloupe", + "America/Guatemala", + "America/Guayaquil", + "America/Guyana", + "America/Halifax", + "America/Havana", + "America/Hermosillo", + "America/Indiana/Indianapolis", + "America/Indiana/Knox", + "America/Indiana/Marengo", + "America/Indiana/Petersburg", + "America/Indiana/Tell_City", + "America/Indiana/Vevay", + "America/Indiana/Vincennes", + "America/Indiana/Winamac", + "America/Indianapolis", + "America/Inuvik", + "America/Iqaluit", + "America/Jamaica", + "America/Jujuy", + "America/Juneau", + "America/Kentucky/Louisville", + "America/Kentucky/Monticello", + "America/Knox_IN", + "America/Kralendijk", + "America/La_Paz", + "America/Lima", + "America/Los_Angeles", + "America/Louisville", + "America/Lower_Princes", + "America/Maceio", + "America/Managua", + "America/Manaus", + "America/Marigot", + "America/Martinique", + "America/Matamoros", + "America/Mazatlan", + "America/Mendoza", + "America/Menominee", + "America/Merida", + "America/Metlakatla", + "America/Mexico_City", + "America/Miquelon", + "America/Moncton", + "America/Monterrey", + "America/Montevideo", + "America/Montreal", + "America/Montserrat", + "America/Nassau", + "America/New_York", + "America/Nipigon", + "America/Nome", + "America/Noronha", + "America/North_Dakota/Beulah", + "America/North_Dakota/Center", + "America/North_Dakota/New_Salem", + "America/Nuuk", + "America/Ojinaga", + "America/Panama", + "America/Pangnirtung", + "America/Paramaribo", + "America/Phoenix", + "America/Port-au-Prince", + "America/Port_of_Spain", + "America/Porto_Acre", + "America/Porto_Velho", + "America/Puerto_Rico", + "America/Punta_Arenas", + "America/Rainy_River", + "America/Rankin_Inlet", + "America/Recife", + "America/Regina", + "America/Resolute", + "America/Rio_Branco", + "America/Rosario", + "America/Santa_Isabel", + "America/Santarem", + "America/Santiago", + "America/Santo_Domingo", + "America/Sao_Paulo", + "America/Scoresbysund", + "America/Shiprock", + "America/Sitka", + "America/St_Barthelemy", + "America/St_Johns", + "America/St_Kitts", + "America/St_Lucia", + "America/St_Thomas", + "America/St_Vincent", + "America/Swift_Current", + "America/Tegucigalpa", + "America/Thule", + "America/Thunder_Bay", + "America/Tijuana", + "America/Toronto", + "America/Tortola", + "America/Vancouver", + "America/Virgin", + "America/Whitehorse", + "America/Winnipeg", + "America/Yakutat", + "America/Yellowknife", + "Antarctica/Casey", + 
"Antarctica/Davis", + "Antarctica/DumontDUrville", + "Antarctica/Macquarie", + "Antarctica/Mawson", + "Antarctica/McMurdo", + "Antarctica/Palmer", + "Antarctica/Rothera", + "Antarctica/South_Pole", + "Antarctica/Syowa", + "Antarctica/Troll", + "Antarctica/Vostok", + "Arctic/Longyearbyen", + "Asia/Aden", + "Asia/Almaty", + "Asia/Amman", + "Asia/Anadyr", + "Asia/Aqtau", + "Asia/Aqtobe", + "Asia/Ashgabat", + "Asia/Ashkhabad", + "Asia/Atyrau", + "Asia/Baghdad", + "Asia/Bahrain", + "Asia/Baku", + "Asia/Bangkok", + "Asia/Barnaul", + "Asia/Beirut", + "Asia/Bishkek", + "Asia/Brunei", + "Asia/Calcutta", + "Asia/Chita", + "Asia/Choibalsan", + "Asia/Chongqing", + "Asia/Chungking", + "Asia/Colombo", + "Asia/Dacca", + "Asia/Damascus", + "Asia/Dhaka", + "Asia/Dili", + "Asia/Dubai", + "Asia/Dushanbe", + "Asia/Famagusta", + "Asia/Gaza", + "Asia/Harbin", + "Asia/Hebron", + "Asia/Ho_Chi_Minh", + "Asia/Hong_Kong", + "Asia/Hovd", + "Asia/Irkutsk", + "Asia/Istanbul", + "Asia/Jakarta", + "Asia/Jayapura", + "Asia/Jerusalem", + "Asia/Kabul", + "Asia/Kamchatka", + "Asia/Karachi", + "Asia/Kashgar", + "Asia/Kathmandu", + "Asia/Katmandu", + "Asia/Khandyga", + "Asia/Kolkata", + "Asia/Krasnoyarsk", + "Asia/Kuala_Lumpur", + "Asia/Kuching", + "Asia/Kuwait", + "Asia/Macao", + "Asia/Macau", + "Asia/Magadan", + "Asia/Makassar", + "Asia/Manila", + "Asia/Muscat", + "Asia/Nicosia", + "Asia/Novokuznetsk", + "Asia/Novosibirsk", + "Asia/Omsk", + "Asia/Oral", + "Asia/Phnom_Penh", + "Asia/Pontianak", + "Asia/Pyongyang", + "Asia/Qatar", + "Asia/Qostanay", + "Asia/Qyzylorda", + "Asia/Rangoon", + "Asia/Riyadh", + "Asia/Saigon", + "Asia/Sakhalin", + "Asia/Samarkand", + "Asia/Seoul", + "Asia/Shanghai", + "Asia/Singapore", + "Asia/Srednekolymsk", + "Asia/Taipei", + "Asia/Tashkent", + "Asia/Tbilisi", + "Asia/Tehran", + "Asia/Tel_Aviv", + "Asia/Thimbu", + "Asia/Thimphu", + "Asia/Tokyo", + "Asia/Tomsk", + "Asia/Ujung_Pandang", + "Asia/Ulaanbaatar", + "Asia/Ulan_Bator", + "Asia/Urumqi", + "Asia/Ust-Nera", + "Asia/Vientiane", + "Asia/Vladivostok", + "Asia/Yakutsk", + "Asia/Yangon", + "Asia/Yekaterinburg", + "Asia/Yerevan", + "Atlantic/Azores", + "Atlantic/Bermuda", + "Atlantic/Canary", + "Atlantic/Cape_Verde", + "Atlantic/Faeroe", + "Atlantic/Faroe", + "Atlantic/Jan_Mayen", + "Atlantic/Madeira", + "Atlantic/Reykjavik", + "Atlantic/South_Georgia", + "Atlantic/St_Helena", + "Atlantic/Stanley", + "Australia/ACT", + "Australia/Adelaide", + "Australia/Brisbane", + "Australia/Broken_Hill", + "Australia/Canberra", + "Australia/Currie", + "Australia/Darwin", + "Australia/Eucla", + "Australia/Hobart", + "Australia/LHI", + "Australia/Lindeman", + "Australia/Lord_Howe", + "Australia/Melbourne", + "Australia/NSW", + "Australia/North", + "Australia/Perth", + "Australia/Queensland", + "Australia/South", + "Australia/Sydney", + "Australia/Tasmania", + "Australia/Victoria", + "Australia/West", + "Australia/Yancowinna", + "Brazil/Acre", + "Brazil/DeNoronha", + "Brazil/East", + "Brazil/West", + "CET", + "CST6CDT", + "Canada/Atlantic", + "Canada/Central", + "Canada/Eastern", + "Canada/Mountain", + "Canada/Newfoundland", + "Canada/Pacific", + "Canada/Saskatchewan", + "Canada/Yukon", + "Chile/Continental", + "Chile/EasterIsland", + "Cuba", + "EET", + "EST", + "EST5EDT", + "Egypt", + "Eire", + "Etc/GMT", + "Etc/GMT+0", + "Etc/GMT+1", + "Etc/GMT+10", + "Etc/GMT+11", + "Etc/GMT+12", + "Etc/GMT+2", + "Etc/GMT+3", + "Etc/GMT+4", + "Etc/GMT+5", + "Etc/GMT+6", + "Etc/GMT+7", + "Etc/GMT+8", + "Etc/GMT+9", + "Etc/GMT-0", + "Etc/GMT-1", + "Etc/GMT-10", + "Etc/GMT-11", + 
"Etc/GMT-12", + "Etc/GMT-13", + "Etc/GMT-14", + "Etc/GMT-2", + "Etc/GMT-3", + "Etc/GMT-4", + "Etc/GMT-5", + "Etc/GMT-6", + "Etc/GMT-7", + "Etc/GMT-8", + "Etc/GMT-9", + "Etc/GMT0", + "Etc/Greenwich", + "Etc/UCT", + "Etc/UTC", + "Etc/Universal", + "Etc/Zulu", + "Europe/Amsterdam", + "Europe/Andorra", + "Europe/Astrakhan", + "Europe/Athens", + "Europe/Belfast", + "Europe/Belgrade", + "Europe/Berlin", + "Europe/Bratislava", + "Europe/Brussels", + "Europe/Bucharest", + "Europe/Budapest", + "Europe/Busingen", + "Europe/Chisinau", + "Europe/Copenhagen", + "Europe/Dublin", + "Europe/Gibraltar", + "Europe/Guernsey", + "Europe/Helsinki", + "Europe/Isle_of_Man", + "Europe/Istanbul", + "Europe/Jersey", + "Europe/Kaliningrad", + "Europe/Kiev", + "Europe/Kirov", + "Europe/Kyiv", + "Europe/Lisbon", + "Europe/Ljubljana", + "Europe/London", + "Europe/Luxembourg", + "Europe/Madrid", + "Europe/Malta", + "Europe/Mariehamn", + "Europe/Minsk", + "Europe/Monaco", + "Europe/Moscow", + "Europe/Nicosia", + "Europe/Oslo", + "Europe/Paris", + "Europe/Podgorica", + "Europe/Prague", + "Europe/Riga", + "Europe/Rome", + "Europe/Samara", + "Europe/San_Marino", + "Europe/Sarajevo", + "Europe/Saratov", + "Europe/Simferopol", + "Europe/Skopje", + "Europe/Sofia", + "Europe/Stockholm", + "Europe/Tallinn", + "Europe/Tirane", + "Europe/Tiraspol", + "Europe/Ulyanovsk", + "Europe/Uzhgorod", + "Europe/Vaduz", + "Europe/Vatican", + "Europe/Vienna", + "Europe/Vilnius", + "Europe/Volgograd", + "Europe/Warsaw", + "Europe/Zagreb", + "Europe/Zaporozhye", + "Europe/Zurich", + "GB", + "GB-Eire", + "GMT", + "GMT+0", + "GMT-0", + "GMT0", + "Greenwich", + "HST", + "Hongkong", + "Iceland", + "Indian/Antananarivo", + "Indian/Chagos", + "Indian/Christmas", + "Indian/Cocos", + "Indian/Comoro", + "Indian/Kerguelen", + "Indian/Mahe", + "Indian/Maldives", + "Indian/Mauritius", + "Indian/Mayotte", + "Indian/Reunion", + "Iran", + "Israel", + "Jamaica", + "Japan", + "Kwajalein", + "Libya", + "MET", + "MST", + "MST7MDT", + "Mexico/BajaNorte", + "Mexico/BajaSur", + "Mexico/General", + "NZ", + "NZ-CHAT", + "Navajo", + "PRC", + "PST8PDT", + "Pacific/Apia", + "Pacific/Auckland", + "Pacific/Bougainville", + "Pacific/Chatham", + "Pacific/Chuuk", + "Pacific/Easter", + "Pacific/Efate", + "Pacific/Enderbury", + "Pacific/Fakaofo", + "Pacific/Fiji", + "Pacific/Funafuti", + "Pacific/Galapagos", + "Pacific/Gambier", + "Pacific/Guadalcanal", + "Pacific/Guam", + "Pacific/Honolulu", + "Pacific/Johnston", + "Pacific/Kanton", + "Pacific/Kiritimati", + "Pacific/Kosrae", + "Pacific/Kwajalein", + "Pacific/Majuro", + "Pacific/Marquesas", + "Pacific/Midway", + "Pacific/Nauru", + "Pacific/Niue", + "Pacific/Norfolk", + "Pacific/Noumea", + "Pacific/Pago_Pago", + "Pacific/Palau", + "Pacific/Pitcairn", + "Pacific/Pohnpei", + "Pacific/Ponape", + "Pacific/Port_Moresby", + "Pacific/Rarotonga", + "Pacific/Saipan", + "Pacific/Samoa", + "Pacific/Tahiti", + "Pacific/Tarawa", + "Pacific/Tongatapu", + "Pacific/Truk", + "Pacific/Wake", + "Pacific/Wallis", + "Pacific/Yap", + "Poland", + "Portugal", + "ROC", + "ROK", + "Singapore", + "Turkey", + "UCT", + "US/Alaska", + "US/Aleutian", + "US/Arizona", + "US/Central", + "US/East-Indiana", + "US/Eastern", + "US/Hawaii", + "US/Indiana-Starke", + "US/Michigan", + "US/Mountain", + "US/Pacific", + "US/Samoa", + "UTC", + "Universal", + "W-SU", + "WET", + "Zulu", + ) +} + + +def subsecond_precision(timestamp_literal: str) -> int: + """ + Given an ISO-8601 timestamp literal, eg '2023-01-01 12:13:14.123456+00:00' + figure out its subsecond precision so 
we can construct types like DATETIME(6) + + Note that in practice, this is either 3 or 6 digits (3 = millisecond precision, 6 = microsecond precision) + - 6 is the maximum because strftime's '%f' formats to microseconds and almost every database supports microsecond precision in timestamps + - Except Presto/Trino which in most cases only supports millisecond precision but will still honour '%f' and format to microseconds (replacing the remaining 3 digits with 0's) + - Python prior to 3.11 only supports 0, 3 or 6 digits in a timestamp literal. Any other amounts will throw a 'ValueError: Invalid isoformat string:' error + """ + try: + parsed = datetime.datetime.fromisoformat(timestamp_literal) + subsecond_digit_count = len(str(parsed.microsecond).rstrip("0")) + precision = 0 + if subsecond_digit_count > 3: + precision = 6 + elif subsecond_digit_count > 0: + precision = 3 + return precision + except ValueError: + return 0 diff --git a/third_party/bigframes_vendored/sqlglot/tokens.py b/third_party/bigframes_vendored/sqlglot/tokens.py new file mode 100644 index 0000000000..b4258989ed --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/tokens.py @@ -0,0 +1,1638 @@ +from __future__ import annotations + +from enum import auto +import os +import typing as t + +from bigframes_vendored.sqlglot.errors import SqlglotError, TokenError +from bigframes_vendored.sqlglot.helper import AutoName +from bigframes_vendored.sqlglot.trie import in_trie, new_trie, TrieResult + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.dialects.dialect import DialectType + + +try: + from bigframes_vendored.sqlglotrs import Tokenizer as RsTokenizer # type: ignore + from bigframes_vendored.sqlglotrs import ( + TokenizerDialectSettings as RsTokenizerDialectSettings, + ) + from bigframes_vendored.sqlglotrs import TokenizerSettings as RsTokenizerSettings + from bigframes_vendored.sqlglotrs import TokenTypeSettings as RsTokenTypeSettings + + USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1" +except ImportError: + USE_RS_TOKENIZER = False + + +class TokenType(AutoName): + L_PAREN = auto() + R_PAREN = auto() + L_BRACKET = auto() + R_BRACKET = auto() + L_BRACE = auto() + R_BRACE = auto() + COMMA = auto() + DOT = auto() + DASH = auto() + PLUS = auto() + COLON = auto() + DOTCOLON = auto() + DCOLON = auto() + DCOLONDOLLAR = auto() + DCOLONPERCENT = auto() + DCOLONQMARK = auto() + DQMARK = auto() + SEMICOLON = auto() + STAR = auto() + BACKSLASH = auto() + SLASH = auto() + LT = auto() + LTE = auto() + GT = auto() + GTE = auto() + NOT = auto() + EQ = auto() + NEQ = auto() + NULLSAFE_EQ = auto() + COLON_EQ = auto() + COLON_GT = auto() + NCOLON_GT = auto() + AND = auto() + OR = auto() + AMP = auto() + DPIPE = auto() + PIPE_GT = auto() + PIPE = auto() + PIPE_SLASH = auto() + DPIPE_SLASH = auto() + CARET = auto() + CARET_AT = auto() + TILDA = auto() + ARROW = auto() + DARROW = auto() + FARROW = auto() + HASH = auto() + HASH_ARROW = auto() + DHASH_ARROW = auto() + LR_ARROW = auto() + DAT = auto() + LT_AT = auto() + AT_GT = auto() + DOLLAR = auto() + PARAMETER = auto() + SESSION = auto() + SESSION_PARAMETER = auto() + SESSION_USER = auto() + DAMP = auto() + AMP_LT = auto() + AMP_GT = auto() + ADJACENT = auto() + XOR = auto() + DSTAR = auto() + QMARK_AMP = auto() + QMARK_PIPE = auto() + HASH_DASH = auto() + EXCLAMATION = auto() + + URI_START = auto() + + BLOCK_START = auto() + BLOCK_END = auto() + + SPACE = auto() + BREAK = auto() + + STRING = auto() + NUMBER = auto() + IDENTIFIER = auto() + DATABASE = auto() + 
COLUMN = auto() + COLUMN_DEF = auto() + SCHEMA = auto() + TABLE = auto() + WAREHOUSE = auto() + STAGE = auto() + STREAMLIT = auto() + VAR = auto() + BIT_STRING = auto() + HEX_STRING = auto() + BYTE_STRING = auto() + NATIONAL_STRING = auto() + RAW_STRING = auto() + HEREDOC_STRING = auto() + UNICODE_STRING = auto() + + # types + BIT = auto() + BOOLEAN = auto() + TINYINT = auto() + UTINYINT = auto() + SMALLINT = auto() + USMALLINT = auto() + MEDIUMINT = auto() + UMEDIUMINT = auto() + INT = auto() + UINT = auto() + BIGINT = auto() + UBIGINT = auto() + BIGNUM = auto() # unlimited precision int + INT128 = auto() + UINT128 = auto() + INT256 = auto() + UINT256 = auto() + FLOAT = auto() + DOUBLE = auto() + UDOUBLE = auto() + DECIMAL = auto() + DECIMAL32 = auto() + DECIMAL64 = auto() + DECIMAL128 = auto() + DECIMAL256 = auto() + DECFLOAT = auto() + UDECIMAL = auto() + BIGDECIMAL = auto() + CHAR = auto() + NCHAR = auto() + VARCHAR = auto() + NVARCHAR = auto() + BPCHAR = auto() + TEXT = auto() + MEDIUMTEXT = auto() + LONGTEXT = auto() + BLOB = auto() + MEDIUMBLOB = auto() + LONGBLOB = auto() + TINYBLOB = auto() + TINYTEXT = auto() + NAME = auto() + BINARY = auto() + VARBINARY = auto() + JSON = auto() + JSONB = auto() + TIME = auto() + TIMETZ = auto() + TIME_NS = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() + TIMESTAMPNTZ = auto() + TIMESTAMP_S = auto() + TIMESTAMP_MS = auto() + TIMESTAMP_NS = auto() + DATETIME = auto() + DATETIME2 = auto() + DATETIME64 = auto() + SMALLDATETIME = auto() + DATE = auto() + DATE32 = auto() + INT4RANGE = auto() + INT4MULTIRANGE = auto() + INT8RANGE = auto() + INT8MULTIRANGE = auto() + NUMRANGE = auto() + NUMMULTIRANGE = auto() + TSRANGE = auto() + TSMULTIRANGE = auto() + TSTZRANGE = auto() + TSTZMULTIRANGE = auto() + DATERANGE = auto() + DATEMULTIRANGE = auto() + UUID = auto() + GEOGRAPHY = auto() + GEOGRAPHYPOINT = auto() + NULLABLE = auto() + GEOMETRY = auto() + POINT = auto() + RING = auto() + LINESTRING = auto() + LOCALTIME = auto() + LOCALTIMESTAMP = auto() + MULTILINESTRING = auto() + POLYGON = auto() + MULTIPOLYGON = auto() + HLLSKETCH = auto() + HSTORE = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() + XML = auto() + YEAR = auto() + USERDEFINED = auto() + MONEY = auto() + SMALLMONEY = auto() + ROWVERSION = auto() + IMAGE = auto() + VARIANT = auto() + OBJECT = auto() + INET = auto() + IPADDRESS = auto() + IPPREFIX = auto() + IPV4 = auto() + IPV6 = auto() + ENUM = auto() + ENUM8 = auto() + ENUM16 = auto() + FIXEDSTRING = auto() + LOWCARDINALITY = auto() + NESTED = auto() + AGGREGATEFUNCTION = auto() + SIMPLEAGGREGATEFUNCTION = auto() + TDIGEST = auto() + UNKNOWN = auto() + VECTOR = auto() + DYNAMIC = auto() + VOID = auto() + + # keywords + ALIAS = auto() + ALTER = auto() + ALL = auto() + ANTI = auto() + ANY = auto() + APPLY = auto() + ARRAY = auto() + ASC = auto() + ASOF = auto() + ATTACH = auto() + AUTO_INCREMENT = auto() + BEGIN = auto() + BETWEEN = auto() + BULK_COLLECT_INTO = auto() + CACHE = auto() + CASE = auto() + CHARACTER_SET = auto() + CLUSTER_BY = auto() + COLLATE = auto() + COMMAND = auto() + COMMENT = auto() + COMMIT = auto() + CONNECT_BY = auto() + CONSTRAINT = auto() + COPY = auto() + CREATE = auto() + CROSS = auto() + CUBE = auto() + CURRENT_DATE = auto() + CURRENT_DATETIME = auto() + CURRENT_SCHEMA = auto() + CURRENT_TIME = auto() + CURRENT_TIMESTAMP = auto() + CURRENT_USER = auto() + CURRENT_ROLE = auto() + CURRENT_CATALOG = auto() + DECLARE = auto() + DEFAULT = auto() + DELETE = 
auto() + DESC = auto() + DESCRIBE = auto() + DETACH = auto() + DICTIONARY = auto() + DISTINCT = auto() + DISTRIBUTE_BY = auto() + DIV = auto() + DROP = auto() + ELSE = auto() + END = auto() + ESCAPE = auto() + EXCEPT = auto() + EXECUTE = auto() + EXISTS = auto() + FALSE = auto() + FETCH = auto() + FILE = auto() + FILE_FORMAT = auto() + FILTER = auto() + FINAL = auto() + FIRST = auto() + FOR = auto() + FORCE = auto() + FOREIGN_KEY = auto() + FORMAT = auto() + FROM = auto() + FULL = auto() + FUNCTION = auto() + GET = auto() + GLOB = auto() + GLOBAL = auto() + GRANT = auto() + GROUP_BY = auto() + GROUPING_SETS = auto() + HAVING = auto() + HINT = auto() + IGNORE = auto() + ILIKE = auto() + IN = auto() + INDEX = auto() + INDEXED_BY = auto() + INNER = auto() + INSERT = auto() + INSTALL = auto() + INTERSECT = auto() + INTERVAL = auto() + INTO = auto() + INTRODUCER = auto() + IRLIKE = auto() + IS = auto() + ISNULL = auto() + JOIN = auto() + JOIN_MARKER = auto() + KEEP = auto() + KEY = auto() + KILL = auto() + LANGUAGE = auto() + LATERAL = auto() + LEFT = auto() + LIKE = auto() + LIMIT = auto() + LIST = auto() + LOAD = auto() + LOCK = auto() + MAP = auto() + MATCH = auto() + MATCH_CONDITION = auto() + MATCH_RECOGNIZE = auto() + MEMBER_OF = auto() + MERGE = auto() + MOD = auto() + MODEL = auto() + NATURAL = auto() + NEXT = auto() + NOTHING = auto() + NOTNULL = auto() + NULL = auto() + OBJECT_IDENTIFIER = auto() + OFFSET = auto() + ON = auto() + ONLY = auto() + OPERATOR = auto() + ORDER_BY = auto() + ORDER_SIBLINGS_BY = auto() + ORDERED = auto() + ORDINALITY = auto() + OUTER = auto() + OVER = auto() + OVERLAPS = auto() + OVERWRITE = auto() + PARTITION = auto() + PARTITION_BY = auto() + PERCENT = auto() + PIVOT = auto() + PLACEHOLDER = auto() + POSITIONAL = auto() + PRAGMA = auto() + PREWHERE = auto() + PRIMARY_KEY = auto() + PROCEDURE = auto() + PROPERTIES = auto() + PSEUDO_TYPE = auto() + PUT = auto() + QUALIFY = auto() + QUOTE = auto() + QDCOLON = auto() + RANGE = auto() + RECURSIVE = auto() + REFRESH = auto() + RENAME = auto() + REPLACE = auto() + RETURNING = auto() + REVOKE = auto() + REFERENCES = auto() + RIGHT = auto() + RLIKE = auto() + ROLLBACK = auto() + ROLLUP = auto() + ROW = auto() + ROWS = auto() + SELECT = auto() + SEMI = auto() + SEPARATOR = auto() + SEQUENCE = auto() + SERDE_PROPERTIES = auto() + SET = auto() + SETTINGS = auto() + SHOW = auto() + SIMILAR_TO = auto() + SOME = auto() + SORT_BY = auto() + SOUNDS_LIKE = auto() + START_WITH = auto() + STORAGE_INTEGRATION = auto() + STRAIGHT_JOIN = auto() + STRUCT = auto() + SUMMARIZE = auto() + TABLE_SAMPLE = auto() + TAG = auto() + TEMPORARY = auto() + TOP = auto() + THEN = auto() + TRUE = auto() + TRUNCATE = auto() + UNCACHE = auto() + UNION = auto() + UNNEST = auto() + UNPIVOT = auto() + UPDATE = auto() + USE = auto() + USING = auto() + VALUES = auto() + VIEW = auto() + SEMANTIC_VIEW = auto() + VOLATILE = auto() + WHEN = auto() + WHERE = auto() + WINDOW = auto() + WITH = auto() + UNIQUE = auto() + UTC_DATE = auto() + UTC_TIME = auto() + UTC_TIMESTAMP = auto() + VERSION_SNAPSHOT = auto() + TIMESTAMP_SNAPSHOT = auto() + OPTION = auto() + SINK = auto() + SOURCE = auto() + ANALYZE = auto() + NAMESPACE = auto() + EXPORT = auto() + + # sentinel + HIVE_TOKEN_STREAM = auto() + + +_ALL_TOKEN_TYPES = list(TokenType) +_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)} + + +class Token: + __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments") + + @classmethod + def number(cls, number: int) 
-> Token: + """Returns a NUMBER token with `number` as its text.""" + return cls(TokenType.NUMBER, str(number)) + + @classmethod + def string(cls, string: str) -> Token: + """Returns a STRING token with `string` as its text.""" + return cls(TokenType.STRING, string) + + @classmethod + def identifier(cls, identifier: str) -> Token: + """Returns an IDENTIFIER token with `identifier` as its text.""" + return cls(TokenType.IDENTIFIER, identifier) + + @classmethod + def var(cls, var: str) -> Token: + """Returns a VAR token with `var` as its text.""" + return cls(TokenType.VAR, var) + + def __init__( + self, + token_type: TokenType, + text: str, + line: int = 1, + col: int = 1, + start: int = 0, + end: int = 0, + comments: t.Optional[t.List[str]] = None, + ) -> None: + """Token initializer. + + Args: + token_type: The TokenType Enum. + text: The text of the token. + line: The line that the token ends on. + col: The column that the token ends on. + start: The start index of the token. + end: The ending index of the token. + comments: The comments to attach to the token. + """ + self.token_type = token_type + self.text = text + self.line = line + self.col = col + self.start = start + self.end = end + self.comments = [] if comments is None else comments + + def __repr__(self) -> str: + attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) + return f"<Token {attributes}>" + + +class _Tokenizer(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: + return dict( + (item, item) if isinstance(item, str) else (item[0], item[1]) + for item in arr + ) + + def _quotes_to_format( + token_type: TokenType, arr: t.List[str | t.Tuple[str, str]] + ) -> t.Dict[str, t.Tuple[str, TokenType]]: + return {k: (v, token_type) for k, v in _convert_quotes(arr).items()} + + klass._QUOTES = _convert_quotes(klass.QUOTES) + klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS) + + klass._FORMAT_STRINGS = { + **{ + p + s: (e, TokenType.NATIONAL_STRING) + for s, e in klass._QUOTES.items() + for p in ("n", "N") + }, + **_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS), + **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS), + **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS), + **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS), + **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS), + **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS), + } + + klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) + klass._ESCAPE_FOLLOW_CHARS = set(klass.ESCAPE_FOLLOW_CHARS) + klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES) + klass._COMMENTS = { + **dict( + (comment, None) + if isinstance(comment, str) + else (comment[0], comment[1]) + for comment in klass.COMMENTS + ), + "{#": "#}", # Ensure Jinja comments are tokenized correctly in all dialects + } + if klass.HINT_START in klass.KEYWORDS: + klass._COMMENTS[klass.HINT_START] = "*/" + + klass._KEYWORD_TRIE = new_trie( + key.upper() + for key in ( + *klass.KEYWORDS, + *klass._COMMENTS, + *klass._QUOTES, + *klass._FORMAT_STRINGS, + ) + if " " in key or any(single in key for single in klass.SINGLE_TOKENS) + ) + + if USE_RS_TOKENIZER: + settings = RsTokenizerSettings( + white_space={ + k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items() + }, + single_tokens={ + k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items() + }, + keywords={ + k: _TOKEN_TYPE_TO_INDEX[v]
for k, v in klass.KEYWORDS.items() + }, + numeric_literals=klass.NUMERIC_LITERALS, + identifiers=klass._IDENTIFIERS, + identifier_escapes=klass._IDENTIFIER_ESCAPES, + string_escapes=klass._STRING_ESCAPES, + quotes=klass._QUOTES, + format_strings={ + k: (v1, _TOKEN_TYPE_TO_INDEX[v2]) + for k, (v1, v2) in klass._FORMAT_STRINGS.items() + }, + has_bit_strings=bool(klass.BIT_STRINGS), + has_hex_strings=bool(klass.HEX_STRINGS), + comments=klass._COMMENTS, + var_single_tokens=klass.VAR_SINGLE_TOKENS, + commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS}, + command_prefix_tokens={ + _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS + }, + heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER, + string_escapes_allowed_in_raw_strings=klass.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS, + nested_comments=klass.NESTED_COMMENTS, + hint_start=klass.HINT_START, + tokens_preceding_hint={ + _TOKEN_TYPE_TO_INDEX[v] for v in klass.TOKENS_PRECEDING_HINT + }, + escape_follow_chars=klass._ESCAPE_FOLLOW_CHARS, + ) + token_types = RsTokenTypeSettings( + bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING], + break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK], + dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON], + heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING], + raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING], + hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING], + identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER], + number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER], + parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER], + semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON], + string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING], + var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR], + heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[ + klass.HEREDOC_STRING_ALTERNATIVE + ], + hint=_TOKEN_TYPE_TO_INDEX[TokenType.HINT], + ) + klass._RS_TOKENIZER = RsTokenizer(settings, token_types) + else: + klass._RS_TOKENIZER = None + + return klass + + +class Tokenizer(metaclass=_Tokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "[": TokenType.L_BRACKET, + "]": TokenType.R_BRACKET, + "{": TokenType.L_BRACE, + "}": TokenType.R_BRACE, + "&": TokenType.AMP, + "^": TokenType.CARET, + ":": TokenType.COLON, + ",": TokenType.COMMA, + ".": TokenType.DOT, + "-": TokenType.DASH, + "=": TokenType.EQ, + ">": TokenType.GT, + "<": TokenType.LT, + "%": TokenType.MOD, + "!": TokenType.NOT, + "|": TokenType.PIPE, + "+": TokenType.PLUS, + ";": TokenType.SEMICOLON, + "/": TokenType.SLASH, + "\\": TokenType.BACKSLASH, + "*": TokenType.STAR, + "~": TokenType.TILDA, + "?": TokenType.PLACEHOLDER, + "@": TokenType.PARAMETER, + "#": TokenType.HASH, + # Used for breaking a var like x'y' but nothing else the token type doesn't matter + "'": TokenType.UNKNOWN, + "`": TokenType.UNKNOWN, + '"': TokenType.UNKNOWN, + } + + BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] + BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] + HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] + RAW_STRINGS: t.List[str | t.Tuple[str, str]] = [] + HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = [] + UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = [] + IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] + QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] + STRING_ESCAPES = ["'"] + VAR_SINGLE_TOKENS: t.Set[str] = set() + ESCAPE_FOLLOW_CHARS: t.List[str] = [] + + # The strings in this list can always be used as escapes, regardless of the surrounding + # identifier delimiters. 
By default, the closing delimiter is assumed to also act as an + # identifier escape, e.g. if we use double-quotes, then they also act as escapes: "x""" + IDENTIFIER_ESCAPES: t.List[str] = [] + + # Whether the heredoc tags follow the same lexical rules as unquoted identifiers + HEREDOC_TAG_IS_IDENTIFIER = False + + # Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc + HEREDOC_STRING_ALTERNATIVE = TokenType.VAR + + # Whether string escape characters function as such when placed within raw strings + STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = True + + NESTED_COMMENTS = True + + HINT_START = "/*+" + + TOKENS_PRECEDING_HINT = { + TokenType.SELECT, + TokenType.INSERT, + TokenType.UPDATE, + TokenType.DELETE, + } + + # Autofilled + _COMMENTS: t.Dict[str, str] = {} + _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} + _IDENTIFIERS: t.Dict[str, str] = {} + _IDENTIFIER_ESCAPES: t.Set[str] = set() + _QUOTES: t.Dict[str, str] = {} + _STRING_ESCAPES: t.Set[str] = set() + _KEYWORD_TRIE: t.Dict = {} + _RS_TOKENIZER: t.Optional[t.Any] = None + _ESCAPE_FOLLOW_CHARS: t.Set[str] = set() + + KEYWORDS: t.Dict[str, TokenType] = { + **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, + **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")}, + **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")}, + **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")}, + HINT_START: TokenType.HINT, + "&<": TokenType.AMP_LT, + "&>": TokenType.AMP_GT, + "==": TokenType.EQ, + "::": TokenType.DCOLON, + "?::": TokenType.QDCOLON, + "||": TokenType.DPIPE, + "|>": TokenType.PIPE_GT, + ">=": TokenType.GTE, + "<=": TokenType.LTE, + "<>": TokenType.NEQ, + "!=": TokenType.NEQ, + ":=": TokenType.COLON_EQ, + "<=>": TokenType.NULLSAFE_EQ, + "->": TokenType.ARROW, + "->>": TokenType.DARROW, + "=>": TokenType.FARROW, + "#>": TokenType.HASH_ARROW, + "#>>": TokenType.DHASH_ARROW, + "<->": TokenType.LR_ARROW, + "&&": TokenType.DAMP, + "??": TokenType.DQMARK, + "~~~": TokenType.GLOB, + "~~": TokenType.LIKE, + "~~*": TokenType.ILIKE, + "~*": TokenType.IRLIKE, + "-|-": TokenType.ADJACENT, + "ALL": TokenType.ALL, + "AND": TokenType.AND, + "ANTI": TokenType.ANTI, + "ANY": TokenType.ANY, + "ASC": TokenType.ASC, + "AS": TokenType.ALIAS, + "ASOF": TokenType.ASOF, + "AUTOINCREMENT": TokenType.AUTO_INCREMENT, + "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, + "BEGIN": TokenType.BEGIN, + "BETWEEN": TokenType.BETWEEN, + "CACHE": TokenType.CACHE, + "UNCACHE": TokenType.UNCACHE, + "CASE": TokenType.CASE, + "CHARACTER SET": TokenType.CHARACTER_SET, + "CLUSTER BY": TokenType.CLUSTER_BY, + "COLLATE": TokenType.COLLATE, + "COLUMN": TokenType.COLUMN, + "COMMIT": TokenType.COMMIT, + "CONNECT BY": TokenType.CONNECT_BY, + "CONSTRAINT": TokenType.CONSTRAINT, + "COPY": TokenType.COPY, + "CREATE": TokenType.CREATE, + "CROSS": TokenType.CROSS, + "CUBE": TokenType.CUBE, + "CURRENT_DATE": TokenType.CURRENT_DATE, + "CURRENT_SCHEMA": TokenType.CURRENT_SCHEMA, + "CURRENT_TIME": TokenType.CURRENT_TIME, + "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, + "CURRENT_USER": TokenType.CURRENT_USER, + "CURRENT_CATALOG": TokenType.CURRENT_CATALOG, + "DATABASE": TokenType.DATABASE, + "DEFAULT": TokenType.DEFAULT, + "DELETE": TokenType.DELETE, + "DESC": TokenType.DESC, + "DESCRIBE": TokenType.DESCRIBE, + "DISTINCT": TokenType.DISTINCT, + "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, + "DIV": TokenType.DIV, + "DROP": TokenType.DROP, + "ELSE": TokenType.ELSE, + "END": TokenType.END, + 
"ENUM": TokenType.ENUM, + "ESCAPE": TokenType.ESCAPE, + "EXCEPT": TokenType.EXCEPT, + "EXECUTE": TokenType.EXECUTE, + "EXISTS": TokenType.EXISTS, + "FALSE": TokenType.FALSE, + "FETCH": TokenType.FETCH, + "FILTER": TokenType.FILTER, + "FILE": TokenType.FILE, + "FIRST": TokenType.FIRST, + "FULL": TokenType.FULL, + "FUNCTION": TokenType.FUNCTION, + "FOR": TokenType.FOR, + "FOREIGN KEY": TokenType.FOREIGN_KEY, + "FORMAT": TokenType.FORMAT, + "FROM": TokenType.FROM, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "GEOMETRY": TokenType.GEOMETRY, + "GLOB": TokenType.GLOB, + "GROUP BY": TokenType.GROUP_BY, + "GROUPING SETS": TokenType.GROUPING_SETS, + "HAVING": TokenType.HAVING, + "ILIKE": TokenType.ILIKE, + "IN": TokenType.IN, + "INDEX": TokenType.INDEX, + "INET": TokenType.INET, + "INNER": TokenType.INNER, + "INSERT": TokenType.INSERT, + "INTERVAL": TokenType.INTERVAL, + "INTERSECT": TokenType.INTERSECT, + "INTO": TokenType.INTO, + "IS": TokenType.IS, + "ISNULL": TokenType.ISNULL, + "JOIN": TokenType.JOIN, + "KEEP": TokenType.KEEP, + "KILL": TokenType.KILL, + "LATERAL": TokenType.LATERAL, + "LEFT": TokenType.LEFT, + "LIKE": TokenType.LIKE, + "LIMIT": TokenType.LIMIT, + "LOAD": TokenType.LOAD, + "LOCALTIME": TokenType.LOCALTIME, + "LOCALTIMESTAMP": TokenType.LOCALTIMESTAMP, + "LOCK": TokenType.LOCK, + "MERGE": TokenType.MERGE, + "NAMESPACE": TokenType.NAMESPACE, + "NATURAL": TokenType.NATURAL, + "NEXT": TokenType.NEXT, + "NOT": TokenType.NOT, + "NOTNULL": TokenType.NOTNULL, + "NULL": TokenType.NULL, + "OBJECT": TokenType.OBJECT, + "OFFSET": TokenType.OFFSET, + "ON": TokenType.ON, + "OR": TokenType.OR, + "XOR": TokenType.XOR, + "ORDER BY": TokenType.ORDER_BY, + "ORDINALITY": TokenType.ORDINALITY, + "OUTER": TokenType.OUTER, + "OVER": TokenType.OVER, + "OVERLAPS": TokenType.OVERLAPS, + "OVERWRITE": TokenType.OVERWRITE, + "PARTITION": TokenType.PARTITION, + "PARTITION BY": TokenType.PARTITION_BY, + "PARTITIONED BY": TokenType.PARTITION_BY, + "PARTITIONED_BY": TokenType.PARTITION_BY, + "PERCENT": TokenType.PERCENT, + "PIVOT": TokenType.PIVOT, + "PRAGMA": TokenType.PRAGMA, + "PRIMARY KEY": TokenType.PRIMARY_KEY, + "PROCEDURE": TokenType.PROCEDURE, + "OPERATOR": TokenType.OPERATOR, + "QUALIFY": TokenType.QUALIFY, + "RANGE": TokenType.RANGE, + "RECURSIVE": TokenType.RECURSIVE, + "REGEXP": TokenType.RLIKE, + "RENAME": TokenType.RENAME, + "REPLACE": TokenType.REPLACE, + "RETURNING": TokenType.RETURNING, + "REFERENCES": TokenType.REFERENCES, + "RIGHT": TokenType.RIGHT, + "RLIKE": TokenType.RLIKE, + "ROLLBACK": TokenType.ROLLBACK, + "ROLLUP": TokenType.ROLLUP, + "ROW": TokenType.ROW, + "ROWS": TokenType.ROWS, + "SCHEMA": TokenType.SCHEMA, + "SELECT": TokenType.SELECT, + "SEMI": TokenType.SEMI, + "SESSION": TokenType.SESSION, + "SESSION_USER": TokenType.SESSION_USER, + "SET": TokenType.SET, + "SETTINGS": TokenType.SETTINGS, + "SHOW": TokenType.SHOW, + "SIMILAR TO": TokenType.SIMILAR_TO, + "SOME": TokenType.SOME, + "SORT BY": TokenType.SORT_BY, + "START WITH": TokenType.START_WITH, + "STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN, + "TABLE": TokenType.TABLE, + "TABLESAMPLE": TokenType.TABLE_SAMPLE, + "TEMP": TokenType.TEMPORARY, + "TEMPORARY": TokenType.TEMPORARY, + "THEN": TokenType.THEN, + "TRUE": TokenType.TRUE, + "TRUNCATE": TokenType.TRUNCATE, + "UNION": TokenType.UNION, + "UNKNOWN": TokenType.UNKNOWN, + "UNNEST": TokenType.UNNEST, + "UNPIVOT": TokenType.UNPIVOT, + "UPDATE": TokenType.UPDATE, + "USE": TokenType.USE, + "USING": TokenType.USING, + "UUID": TokenType.UUID, + "VALUES": TokenType.VALUES, + "VIEW": 
TokenType.VIEW, + "VOLATILE": TokenType.VOLATILE, + "WHEN": TokenType.WHEN, + "WHERE": TokenType.WHERE, + "WINDOW": TokenType.WINDOW, + "WITH": TokenType.WITH, + "APPLY": TokenType.APPLY, + "ARRAY": TokenType.ARRAY, + "BIT": TokenType.BIT, + "BOOL": TokenType.BOOLEAN, + "BOOLEAN": TokenType.BOOLEAN, + "BYTE": TokenType.TINYINT, + "MEDIUMINT": TokenType.MEDIUMINT, + "INT1": TokenType.TINYINT, + "TINYINT": TokenType.TINYINT, + "INT16": TokenType.SMALLINT, + "SHORT": TokenType.SMALLINT, + "SMALLINT": TokenType.SMALLINT, + "HUGEINT": TokenType.INT128, + "UHUGEINT": TokenType.UINT128, + "INT2": TokenType.SMALLINT, + "INTEGER": TokenType.INT, + "INT": TokenType.INT, + "INT4": TokenType.INT, + "INT32": TokenType.INT, + "INT64": TokenType.BIGINT, + "INT128": TokenType.INT128, + "INT256": TokenType.INT256, + "LONG": TokenType.BIGINT, + "BIGINT": TokenType.BIGINT, + "INT8": TokenType.TINYINT, + "UINT": TokenType.UINT, + "UINT128": TokenType.UINT128, + "UINT256": TokenType.UINT256, + "DEC": TokenType.DECIMAL, + "DECIMAL": TokenType.DECIMAL, + "DECIMAL32": TokenType.DECIMAL32, + "DECIMAL64": TokenType.DECIMAL64, + "DECIMAL128": TokenType.DECIMAL128, + "DECIMAL256": TokenType.DECIMAL256, + "DECFLOAT": TokenType.DECFLOAT, + "BIGDECIMAL": TokenType.BIGDECIMAL, + "BIGNUMERIC": TokenType.BIGDECIMAL, + "BIGNUM": TokenType.BIGNUM, + "LIST": TokenType.LIST, + "MAP": TokenType.MAP, + "NULLABLE": TokenType.NULLABLE, + "NUMBER": TokenType.DECIMAL, + "NUMERIC": TokenType.DECIMAL, + "FIXED": TokenType.DECIMAL, + "REAL": TokenType.FLOAT, + "FLOAT": TokenType.FLOAT, + "FLOAT4": TokenType.FLOAT, + "FLOAT8": TokenType.DOUBLE, + "DOUBLE": TokenType.DOUBLE, + "DOUBLE PRECISION": TokenType.DOUBLE, + "JSON": TokenType.JSON, + "JSONB": TokenType.JSONB, + "CHAR": TokenType.CHAR, + "CHARACTER": TokenType.CHAR, + "CHAR VARYING": TokenType.VARCHAR, + "CHARACTER VARYING": TokenType.VARCHAR, + "NCHAR": TokenType.NCHAR, + "VARCHAR": TokenType.VARCHAR, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR": TokenType.NVARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + "BPCHAR": TokenType.BPCHAR, + "STR": TokenType.TEXT, + "STRING": TokenType.TEXT, + "TEXT": TokenType.TEXT, + "LONGTEXT": TokenType.LONGTEXT, + "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "TINYTEXT": TokenType.TINYTEXT, + "CLOB": TokenType.TEXT, + "LONGVARCHAR": TokenType.TEXT, + "BINARY": TokenType.BINARY, + "BLOB": TokenType.VARBINARY, + "LONGBLOB": TokenType.LONGBLOB, + "MEDIUMBLOB": TokenType.MEDIUMBLOB, + "TINYBLOB": TokenType.TINYBLOB, + "BYTEA": TokenType.VARBINARY, + "VARBINARY": TokenType.VARBINARY, + "TIME": TokenType.TIME, + "TIMETZ": TokenType.TIMETZ, + "TIME_NS": TokenType.TIME_NS, + "TIMESTAMP": TokenType.TIMESTAMP, + "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, + "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, + "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, + "TIMESTAMPNTZ": TokenType.TIMESTAMPNTZ, + "TIMESTAMP_NTZ": TokenType.TIMESTAMPNTZ, + "DATE": TokenType.DATE, + "DATETIME": TokenType.DATETIME, + "INT4RANGE": TokenType.INT4RANGE, + "INT4MULTIRANGE": TokenType.INT4MULTIRANGE, + "INT8RANGE": TokenType.INT8RANGE, + "INT8MULTIRANGE": TokenType.INT8MULTIRANGE, + "NUMRANGE": TokenType.NUMRANGE, + "NUMMULTIRANGE": TokenType.NUMMULTIRANGE, + "TSRANGE": TokenType.TSRANGE, + "TSMULTIRANGE": TokenType.TSMULTIRANGE, + "TSTZRANGE": TokenType.TSTZRANGE, + "TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE, + "DATERANGE": TokenType.DATERANGE, + "DATEMULTIRANGE": TokenType.DATEMULTIRANGE, + "UNIQUE": TokenType.UNIQUE, + "VECTOR": TokenType.VECTOR, + "STRUCT": TokenType.STRUCT, + "SEQUENCE": TokenType.SEQUENCE, 
+ "VARIANT": TokenType.VARIANT, + "ALTER": TokenType.ALTER, + "ANALYZE": TokenType.ANALYZE, + "CALL": TokenType.COMMAND, + "COMMENT": TokenType.COMMENT, + "EXPLAIN": TokenType.COMMAND, + "GRANT": TokenType.GRANT, + "REVOKE": TokenType.REVOKE, + "OPTIMIZE": TokenType.COMMAND, + "PREPARE": TokenType.COMMAND, + "VACUUM": TokenType.COMMAND, + "USER-DEFINED": TokenType.USERDEFINED, + "FOR VERSION": TokenType.VERSION_SNAPSHOT, + "FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT, + } + + WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = { + " ": TokenType.SPACE, + "\t": TokenType.SPACE, + "\n": TokenType.BREAK, + "\r": TokenType.BREAK, + } + + COMMANDS = { + TokenType.COMMAND, + TokenType.EXECUTE, + TokenType.FETCH, + TokenType.SHOW, + TokenType.RENAME, + } + + COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN} + + # Handle numeric literals like in hive (3L = BIGINT) + NUMERIC_LITERALS: t.Dict[str, str] = {} + + COMMENTS = ["--", ("/*", "*/")] + + __slots__ = ( + "sql", + "size", + "tokens", + "dialect", + "use_rs_tokenizer", + "_start", + "_current", + "_line", + "_col", + "_comments", + "_char", + "_end", + "_peek", + "_prev_token_line", + "_rs_dialect_settings", + ) + + def __init__( + self, + dialect: DialectType = None, + use_rs_tokenizer: t.Optional[bool] = None, + **opts: t.Any, + ) -> None: + from bigframes_vendored.sqlglot.dialects import Dialect + + self.dialect = Dialect.get_or_raise(dialect) + + # initialize `use_rs_tokenizer`, and allow it to be overwritten per Tokenizer instance + self.use_rs_tokenizer = ( + use_rs_tokenizer if use_rs_tokenizer is not None else USE_RS_TOKENIZER + ) + + if self.use_rs_tokenizer: + self._rs_dialect_settings = RsTokenizerDialectSettings( + unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES, + identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT, + numbers_can_be_underscore_separated=self.dialect.NUMBERS_CAN_BE_UNDERSCORE_SEPARATED, + ) + + self.reset() + + def reset(self) -> None: + self.sql = "" + self.size = 0 + self.tokens: t.List[Token] = [] + self._start = 0 + self._current = 0 + self._line = 1 + self._col = 0 + self._comments: t.List[str] = [] + + self._char = "" + self._end = False + self._peek = "" + self._prev_token_line = -1 + + def tokenize(self, sql: str) -> t.List[Token]: + """Returns a list of tokens corresponding to the SQL string `sql`.""" + if self.use_rs_tokenizer: + return self.tokenize_rs(sql) + + self.reset() + self.sql = sql + self.size = len(sql) + + try: + self._scan() + except Exception as e: + start = max(self._current - 50, 0) + end = min(self._current + 50, self.size - 1) + context = self.sql[start:end] + raise TokenError(f"Error tokenizing '{context}'") from e + + return self.tokens + + def _scan(self, until: t.Optional[t.Callable] = None) -> None: + while self.size and not self._end: + current = self._current + + # Skip spaces here rather than iteratively calling advance() for performance reasons + while current < self.size: + char = self.sql[current] + + if char.isspace() and (char == " " or char == "\t"): + current += 1 + else: + break + + offset = current - self._current if current > self._current else 1 + + self._start = current + self._advance(offset) + + if not self._char.isspace(): + if self._char.isdigit(): + self._scan_number() + elif self._char in self._IDENTIFIERS: + self._scan_identifier(self._IDENTIFIERS[self._char]) + else: + self._scan_keywords() + + if until and until(): + break + + if self.tokens and self._comments: + self.tokens[-1].comments.extend(self._comments) + + 
def _chars(self, size: int) -> str: + if size == 1: + return self._char + + start = self._current - 1 + end = start + size + + return self.sql[start:end] if end <= self.size else "" + + def _advance(self, i: int = 1, alnum: bool = False) -> None: + if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: + # Ensures we don't count an extra line if we get a \r\n line break sequence + if not (self._char == "\r" and self._peek == "\n"): + self._col = i + self._line += 1 + else: + self._col += i + + self._current += i + self._end = self._current >= self.size + self._char = self.sql[self._current - 1] + self._peek = "" if self._end else self.sql[self._current] + + if alnum and self._char.isalnum(): + # Here we use local variables instead of attributes for better performance + _col = self._col + _current = self._current + _end = self._end + _peek = self._peek + + while _peek.isalnum(): + _col += 1 + _current += 1 + _end = _current >= self.size + _peek = "" if _end else self.sql[_current] + + self._col = _col + self._current = _current + self._end = _end + self._peek = _peek + self._char = self.sql[_current - 1] + + @property + def _text(self) -> str: + return self.sql[self._start : self._current] + + def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: + self._prev_token_line = self._line + + if self._comments and token_type == TokenType.SEMICOLON and self.tokens: + self.tokens[-1].comments.extend(self._comments) + self._comments = [] + + self.tokens.append( + Token( + token_type, + text=self._text if text is None else text, + line=self._line, + col=self._col, + start=self._start, + end=self._current - 1, + comments=self._comments, + ) + ) + self._comments = [] + + # If we have either a semicolon or a begin token before the command's token, we'll parse + # whatever follows the command's token as a string + if ( + token_type in self.COMMANDS + and self._peek != ";" + and ( + len(self.tokens) == 1 + or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS + ) + ): + start = self._current + tokens = len(self.tokens) + self._scan(lambda: self._peek == ";") + self.tokens = self.tokens[:tokens] + text = self.sql[start : self._current].strip() + if text: + self._add(TokenType.STRING, text) + + def _scan_keywords(self) -> None: + size = 0 + word = None + chars = self._text + char = chars + prev_space = False + skip = False + trie = self._KEYWORD_TRIE + single_token = char in self.SINGLE_TOKENS + + while chars: + if skip: + result = TrieResult.PREFIX + else: + result, trie = in_trie(trie, char.upper()) + + if result == TrieResult.FAILED: + break + if result == TrieResult.EXISTS: + word = chars + + end = self._current + size + size += 1 + + if end < self.size: + char = self.sql[end] + single_token = single_token or char in self.SINGLE_TOKENS + is_space = char.isspace() + + if not is_space or not prev_space: + if is_space: + char = " " + chars += char + prev_space = is_space + skip = False + else: + skip = True + else: + char = "" + break + + if word: + if self._scan_string(word): + return + if self._scan_comment(word): + return + if prev_space or single_token or not char: + self._advance(size - 1) + word = word.upper() + self._add(self.KEYWORDS[word], text=word) + return + + if self._char in self.SINGLE_TOKENS: + self._add(self.SINGLE_TOKENS[self._char], text=self._char) + return + + self._scan_var() + + def _scan_comment(self, comment_start: str) -> bool: + if comment_start not in self._COMMENTS: + return False + + comment_start_line = self._line + comment_start_size = 
len(comment_start) + comment_end = self._COMMENTS[comment_start] + + if comment_end: + # Skip the comment's start delimiter + self._advance(comment_start_size) + + comment_count = 1 + comment_end_size = len(comment_end) + + while not self._end: + if self._chars(comment_end_size) == comment_end: + comment_count -= 1 + if not comment_count: + break + + self._advance(alnum=True) + + # Nested comments are allowed by some dialects, e.g. databricks, duckdb, postgres + if ( + self.NESTED_COMMENTS + and not self._end + and self._chars(comment_end_size) == comment_start + ): + self._advance(comment_start_size) + comment_count += 1 + + self._comments.append( + self._text[comment_start_size : -comment_end_size + 1] + ) + self._advance(comment_end_size - 1) + else: + while ( + not self._end + and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK + ): + self._advance(alnum=True) + self._comments.append(self._text[comment_start_size:]) + + if ( + comment_start == self.HINT_START + and self.tokens + and self.tokens[-1].token_type in self.TOKENS_PRECEDING_HINT + ): + self._add(TokenType.HINT) + + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + # Multiple consecutive comments are preserved by appending them to the current comments list. + if comment_start_line == self._prev_token_line: + self.tokens[-1].comments.extend(self._comments) + self._comments = [] + self._prev_token_line = self._line + + return True + + def _scan_number(self) -> None: + if self._char == "0": + peek = self._peek.upper() + if peek == "B": + return ( + self._scan_bits() + if self.BIT_STRINGS + else self._add(TokenType.NUMBER) + ) + elif peek == "X": + return ( + self._scan_hex() + if self.HEX_STRINGS + else self._add(TokenType.NUMBER) + ) + + decimal = False + scientific = 0 + + while True: + if self._peek.isdigit(): + self._advance() + elif self._peek == "." 
and not decimal: + if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER: + return self._add(TokenType.NUMBER) + decimal = True + self._advance() + elif self._peek in ("-", "+") and scientific == 1: + # Only consume +/- if followed by a digit + if ( + self._current + 1 < self.size + and self.sql[self._current + 1].isdigit() + ): + scientific += 1 + self._advance() + else: + return self._add(TokenType.NUMBER) + elif self._peek.upper() == "E" and not scientific: + scientific += 1 + self._advance() + elif self._peek == "_" and self.dialect.NUMBERS_CAN_BE_UNDERSCORE_SEPARATED: + self._advance() + elif self._peek.isidentifier(): + number_text = self._text + literal = "" + + while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: + literal += self._peek + self._advance() + + token_type = self.KEYWORDS.get( + self.NUMERIC_LITERALS.get(literal.upper(), "") + ) + + if token_type: + self._add(TokenType.NUMBER, number_text) + self._add(TokenType.DCOLON, "::") + return self._add(token_type, literal) + elif self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT: + return self._add(TokenType.VAR) + + self._advance(-len(literal)) + return self._add(TokenType.NUMBER, number_text) + else: + return self._add(TokenType.NUMBER) + + def _scan_bits(self) -> None: + self._advance() + value = self._extract_value() + try: + # If `value` can't be converted to a binary, fallback to tokenizing it as an identifier + int(value, 2) + self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b + except ValueError: + self._add(TokenType.IDENTIFIER) + + def _scan_hex(self) -> None: + self._advance() + value = self._extract_value() + try: + # If `value` can't be converted to a hex, fallback to tokenizing it as an identifier + int(value, 16) + self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x + except ValueError: + self._add(TokenType.IDENTIFIER) + + def _extract_value(self) -> str: + while True: + char = self._peek.strip() + if char and char not in self.SINGLE_TOKENS: + self._advance(alnum=True) + else: + break + + return self._text + + def _scan_string(self, start: str) -> bool: + base = None + token_type = TokenType.STRING + + if start in self._QUOTES: + end = self._QUOTES[start] + elif start in self._FORMAT_STRINGS: + end, token_type = self._FORMAT_STRINGS[start] + + if token_type == TokenType.HEX_STRING: + base = 16 + elif token_type == TokenType.BIT_STRING: + base = 2 + elif token_type == TokenType.HEREDOC_STRING: + self._advance() + + if self._char == end: + tag = "" + else: + tag = self._extract_string( + end, + raw_string=True, + raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER, + ) + + if ( + tag + and self.HEREDOC_TAG_IS_IDENTIFIER + and (self._end or tag.isdigit() or any(c.isspace() for c in tag)) + ): + if not self._end: + self._advance(-1) + + self._advance(-len(tag)) + self._add(self.HEREDOC_STRING_ALTERNATIVE) + return True + + end = f"{start}{tag}{end}" + else: + return False + + self._advance(len(start)) + text = self._extract_string(end, raw_string=token_type == TokenType.RAW_STRING) + + if base and text: + try: + int(text, base) + except Exception: + raise TokenError( + f"Numeric string contains invalid characters from {self._line}:{self._start}" + ) + + self._add(token_type, text) + return True + + def _scan_identifier(self, identifier_end: str) -> None: + self._advance() + text = self._extract_string( + identifier_end, escapes=self._IDENTIFIER_ESCAPES | {identifier_end} + ) + self._add(TokenType.IDENTIFIER, text) + + def _scan_var(self) -> None: + while True: + char = 
self._peek.strip() + if char and ( + char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS + ): + self._advance(alnum=True) + else: + break + + self._add( + TokenType.VAR + if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER + else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) + ) + + def _extract_string( + self, + delimiter: str, + escapes: t.Optional[t.Set[str]] = None, + raw_string: bool = False, + raise_unmatched: bool = True, + ) -> str: + text = "" + delim_size = len(delimiter) + escapes = self._STRING_ESCAPES if escapes is None else escapes + + while True: + if ( + not raw_string + and self.dialect.UNESCAPED_SEQUENCES + and self._peek + and self._char in self.STRING_ESCAPES + ): + unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get( + self._char + self._peek + ) + if unescaped_sequence: + self._advance(2) + text += unescaped_sequence + continue + + is_valid_custom_escape = ( + self.ESCAPE_FOLLOW_CHARS + and self._char == "\\" + and self._peek not in self.ESCAPE_FOLLOW_CHARS + ) + + if ( + (self.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS or not raw_string) + and self._char in escapes + and ( + self._peek == delimiter + or self._peek in escapes + or is_valid_custom_escape + ) + and (self._char not in self._QUOTES or self._char == self._peek) + ): + if self._peek == delimiter: + text += self._peek + elif is_valid_custom_escape and self._char != self._peek: + text += self._peek + else: + text += self._char + self._peek + + if self._current + 1 < self.size: + self._advance(2) + else: + raise TokenError( + f"Missing {delimiter} from {self._line}:{self._current}" + ) + else: + if self._chars(delim_size) == delimiter: + if delim_size > 1: + self._advance(delim_size - 1) + break + + if self._end: + if not raise_unmatched: + return text + self._char + + raise TokenError( + f"Missing {delimiter} from {self._line}:{self._start}" + ) + + current = self._current - 1 + self._advance(alnum=True) + text += self.sql[current : self._current - 1] + + return text + + def tokenize_rs(self, sql: str) -> t.List[Token]: + if not self._RS_TOKENIZER: + raise SqlglotError("Rust tokenizer is not available") + + tokens, error_msg = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings) + for token in tokens: + token.token_type = _ALL_TOKEN_TYPES[token.token_type_index] + + # Setting this here so partial token lists can be inspected even if there is a failure + self.tokens = tokens + + if error_msg is not None: + raise TokenError(error_msg) + + return tokens diff --git a/third_party/bigframes_vendored/sqlglot/transforms.py b/third_party/bigframes_vendored/sqlglot/transforms.py new file mode 100644 index 0000000000..2317338e5c --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/transforms.py @@ -0,0 +1,1125 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import expressions as exp +from bigframes_vendored.sqlglot.errors import UnsupportedError +from bigframes_vendored.sqlglot.helper import find_new_name, name_sequence, seq_get + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.generator import Generator + + +def preprocess( + transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], + generator: t.Optional[t.Callable[[Generator, exp.Expression], str]] = None, +) -> t.Callable[[Generator, exp.Expression], str]: + """ + Creates a new transform by chaining a sequence of transformations and converts the resulting + expression to SQL, using either the "_sql" 
method corresponding to the resulting expression, + or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). + + Args: + transforms: sequence of transform functions. These will be called in order. + + Returns: + Function that can be used as a generator transform. + """ + + def _to_sql(self, expression: exp.Expression) -> str: + expression_type = type(expression) + + try: + expression = transforms[0](expression) + for transform in transforms[1:]: + expression = transform(expression) + except UnsupportedError as unsupported_error: + self.unsupported(str(unsupported_error)) + + if generator: + return generator(self, expression) + + _sql_handler = getattr(self, expression.key + "_sql", None) + if _sql_handler: + return _sql_handler(expression) + + transforms_handler = self.TRANSFORMS.get(type(expression)) + if transforms_handler: + if expression_type is type(expression): + if isinstance(expression, exp.Func): + return self.function_fallback_sql(expression) + + # Ensures we don't enter an infinite loop. This can happen when the original expression + # has the same type as the final expression and there's no _sql method available for it, + # because then it'd re-enter _to_sql. + raise ValueError( + f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." + ) + + return transforms_handler(self, expression) + + raise ValueError( + f"Unsupported expression type {expression.__class__.__name__}." + ) + + return _to_sql + + +def unnest_generate_date_array_using_recursive_cte( + expression: exp.Expression, +) -> exp.Expression: + if isinstance(expression, exp.Select): + count = 0 + recursive_ctes = [] + + for unnest in expression.find_all(exp.Unnest): + if ( + not isinstance(unnest.parent, (exp.From, exp.Join)) + or len(unnest.expressions) != 1 + or not isinstance(unnest.expressions[0], exp.GenerateDateArray) + ): + continue + + generate_date_array = unnest.expressions[0] + start = generate_date_array.args.get("start") + end = generate_date_array.args.get("end") + step = generate_date_array.args.get("step") + + if not start or not end or not isinstance(step, exp.Interval): + continue + + alias = unnest.args.get("alias") + column_name = ( + alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value" + ) + + start = exp.cast(start, "date") + date_add = exp.func( + "date_add", + column_name, + exp.Literal.number(step.name), + step.args.get("unit"), + ) + cast_date_add = exp.cast(date_add, "date") + + cte_name = "_generated_dates" + (f"_{count}" if count else "") + + base_query = exp.select(start.as_(column_name)) + recursive_query = ( + exp.select(cast_date_add) + .from_(cte_name) + .where(cast_date_add <= exp.cast(end, "date")) + ) + cte_query = base_query.union(recursive_query, distinct=False) + + generate_dates_query = exp.select(column_name).from_(cte_name) + unnest.replace(generate_dates_query.subquery(cte_name)) + + recursive_ctes.append( + exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name]) + ) + count += 1 + + if recursive_ctes: + with_expression = expression.args.get("with_") or exp.With() + with_expression.set("recursive", True) + with_expression.set( + "expressions", [*recursive_ctes, *with_expression.expressions] + ) + expression.set("with_", with_expression) + + return expression + + +def unnest_generate_series(expression: exp.Expression) -> exp.Expression: + """Unnests GENERATE_SERIES or SEQUENCE table references.""" + this = expression.this + if isinstance(expression, exp.Table) and 
isinstance(this, exp.GenerateSeries): + unnest = exp.Unnest(expressions=[this]) + if expression.alias: + return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False) + + return unnest + + return expression + + +def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: + """ + Convert SELECT DISTINCT ON statements to a subquery with a window function. + + This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. + + Args: + expression: the expression that will be transformed. + + Returns: + The transformed expression. + """ + if ( + isinstance(expression, exp.Select) + and expression.args.get("distinct") + and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple) + ): + row_number_window_alias = find_new_name(expression.named_selects, "_row_number") + + distinct_cols = expression.args["distinct"].pop().args["on"].expressions + window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) + + order = expression.args.get("order") + if order: + window.set("order", order.pop()) + else: + window.set( + "order", exp.Order(expressions=[c.copy() for c in distinct_cols]) + ) + + window = exp.alias_(window, row_number_window_alias) + expression.select(window, copy=False) + + # We add aliases to the projections so that we can safely reference them in the outer query + new_selects = [] + taken_names = {row_number_window_alias} + for select in expression.selects[:-1]: + if select.is_star: + new_selects = [exp.Star()] + break + + if not isinstance(select, exp.Alias): + alias = find_new_name(taken_names, select.output_name or "_col") + quoted = ( + select.this.args.get("quoted") + if isinstance(select, exp.Column) + else None + ) + select = select.replace(exp.alias_(select, alias, quoted=quoted)) + + taken_names.add(select.output_name) + new_selects.append(select.args["alias"]) + + return ( + exp.select(*new_selects, copy=False) + .from_(expression.subquery("_t", copy=False), copy=False) + .where(exp.column(row_number_window_alias).eq(1), copy=False) + ) + + return expression + + +def eliminate_qualify(expression: exp.Expression) -> exp.Expression: + """ + Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. + + The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: + https://docs.snowflake.com/en/sql-reference/constructs/qualify + + Some dialects don't support window functions in the WHERE clause, so we need to include them as + projections in the subquery, in order to refer to them in the outer filter using aliases. Also, + if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, + otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a + newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the + corresponding expression to avoid creating invalid column references. 
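+ + Example (editorial sketch of the rewrite's shape; the generated aliases may differ): + + SELECT c FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1 + + is rewritten to roughly + + SELECT c FROM (SELECT c, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM t) AS _t WHERE _w = 1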
+ """ + if isinstance(expression, exp.Select) and expression.args.get("qualify"): + taken = set(expression.named_selects) + for select in expression.selects: + if not select.alias_or_name: + alias = find_new_name(taken, "_c") + select.replace(exp.alias_(select, alias)) + taken.add(alias) + + def _select_alias_or_name(select: exp.Expression) -> str | exp.Column: + alias_or_name = select.alias_or_name + identifier = select.args.get("alias") or select.this + if isinstance(identifier, exp.Identifier): + return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) + return alias_or_name + + outer_selects = exp.select( + *list(map(_select_alias_or_name, expression.selects)) + ) + qualify_filters = expression.args["qualify"].pop().this + expression_by_alias = { + select.alias: select.this + for select in expression.selects + if isinstance(select, exp.Alias) + } + + select_candidates = ( + exp.Window if expression.is_star else (exp.Window, exp.Column) + ) + for select_candidate in list(qualify_filters.find_all(select_candidates)): + if isinstance(select_candidate, exp.Window): + if expression_by_alias: + for column in select_candidate.find_all(exp.Column): + expr = expression_by_alias.get(column.name) + if expr: + column.replace(expr) + + alias = find_new_name(expression.named_selects, "_w") + expression.select(exp.alias_(select_candidate, alias), copy=False) + column = exp.column(alias) + + if isinstance(select_candidate.parent, exp.Qualify): + qualify_filters = column + else: + select_candidate.replace(column) + elif select_candidate.name not in expression.named_selects: + expression.select(select_candidate.copy(), copy=False) + + return outer_selects.from_( + expression.subquery(alias="_t", copy=False), copy=False + ).where(qualify_filters, copy=False) + + return expression + + +def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: + """ + Some dialects only allow the precision for parameterized types to be defined in the DDL and not in + other expressions. This transforms removes the precision from parameterized types in expressions. 
+ """ + for node in expression.find_all(exp.DataType): + node.set( + "expressions", + [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)], + ) + + return expression + + +def unqualify_unnest(expression: exp.Expression) -> exp.Expression: + """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" + from bigframes_vendored.sqlglot.optimizer.scope import find_all_in_scope + + if isinstance(expression, exp.Select): + unnest_aliases = { + unnest.alias + for unnest in find_all_in_scope(expression, exp.Unnest) + if isinstance(unnest.parent, (exp.From, exp.Join)) + } + if unnest_aliases: + for column in expression.find_all(exp.Column): + leftmost_part = column.parts[0] + if ( + leftmost_part.arg_key != "this" + and leftmost_part.this in unnest_aliases + ): + leftmost_part.pop() + + return expression + + +def unnest_to_explode( + expression: exp.Expression, + unnest_using_arrays_zip: bool = True, +) -> exp.Expression: + """Convert cross join unnest into lateral view explode.""" + + def _unnest_zip_exprs( + u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool + ) -> t.List[exp.Expression]: + if has_multi_expr: + if not unnest_using_arrays_zip: + raise UnsupportedError( + "Cannot transpile UNNEST with multiple input arrays" + ) + + # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions + zip_exprs: t.List[exp.Expression] = [ + exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs) + ] + u.set("expressions", zip_exprs) + return zip_exprs + return unnest_exprs + + def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: + if u.args.get("offset"): + return exp.Posexplode + return exp.Inline if has_multi_expr else exp.Explode + + if isinstance(expression, exp.Select): + from_ = expression.args.get("from_") + + if from_ and isinstance(from_.this, exp.Unnest): + unnest = from_.this + alias = unnest.args.get("alias") + exprs = unnest.expressions + has_multi_expr = len(exprs) > 1 + this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr) + + columns = alias.columns if alias else [] + offset = unnest.args.get("offset") + if offset: + columns.insert( + 0, + offset + if isinstance(offset, exp.Identifier) + else exp.to_identifier("pos"), + ) + + unnest.replace( + exp.Table( + this=_udtf_type(unnest, has_multi_expr)(this=this), + alias=exp.TableAlias(this=alias.this, columns=columns) + if alias + else None, + ) + ) + + joins = expression.args.get("joins") or [] + for join in list(joins): + join_expr = join.this + + is_lateral = isinstance(join_expr, exp.Lateral) + + unnest = join_expr.this if is_lateral else join_expr + + if isinstance(unnest, exp.Unnest): + if is_lateral: + alias = join_expr.args.get("alias") + else: + alias = unnest.args.get("alias") + exprs = unnest.expressions + # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here + has_multi_expr = len(exprs) > 1 + exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr) + + joins.remove(join) + + alias_cols = alias.columns if alias else [] + + # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases + # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount. 
+ # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html + + if not has_multi_expr and len(alias_cols) not in (1, 2): + raise UnsupportedError( + "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases" + ) + + offset = unnest.args.get("offset") + if offset: + alias_cols.insert( + 0, + offset + if isinstance(offset, exp.Identifier) + else exp.to_identifier("pos"), + ) + + for e, column in zip(exprs, alias_cols): + expression.append( + "laterals", + exp.Lateral( + this=_udtf_type(unnest, has_multi_expr)(this=e), + view=True, + alias=exp.TableAlias( + this=alias.this, # type: ignore + columns=alias_cols, + ), + ), + ) + + return expression + + +def explode_projection_to_unnest( + index_offset: int = 0, +) -> t.Callable[[exp.Expression], exp.Expression]: + """Convert explode/posexplode projections into unnests.""" + + def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + from bigframes_vendored.sqlglot.optimizer.scope import Scope + + taken_select_names = set(expression.named_selects) + taken_source_names = {name for name, _ in Scope(expression).references} + + def new_name(names: t.Set[str], name: str) -> str: + name = find_new_name(names, name) + names.add(name) + return name + + arrays: t.List[exp.Condition] = [] + series_alias = new_name(taken_select_names, "pos") + series = exp.alias_( + exp.Unnest( + expressions=[ + exp.GenerateSeries(start=exp.Literal.number(index_offset)) + ] + ), + new_name(taken_source_names, "_u"), + table=[series_alias], + ) + + # we use list here because expression.selects is mutated inside the loop + for select in list(expression.selects): + explode = select.find(exp.Explode) + + if explode: + pos_alias = "" + explode_alias = "" + + if isinstance(select, exp.Alias): + explode_alias = select.args["alias"] + alias = select + elif isinstance(select, exp.Aliases): + pos_alias = select.aliases[0] + explode_alias = select.aliases[1] + alias = select.replace(exp.alias_(select.this, "", copy=False)) + else: + alias = select.replace(exp.alias_(select, "")) + explode = alias.find(exp.Explode) + assert explode + + is_posexplode = isinstance(explode, exp.Posexplode) + explode_arg = explode.this + + if isinstance(explode, exp.ExplodeOuter): + bracket = explode_arg[0] + bracket.set("safe", True) + bracket.set("offset", True) + explode_arg = exp.func( + "IF", + exp.func( + "ARRAY_SIZE", + exp.func("COALESCE", explode_arg, exp.Array()), + ).eq(0), + exp.array(bracket, copy=False), + explode_arg, + ) + + # This ensures that we won't use [POS]EXPLODE's argument as a new selection + if isinstance(explode_arg, exp.Column): + taken_select_names.add(explode_arg.output_name) + + unnest_source_alias = new_name(taken_source_names, "_u") + + if not explode_alias: + explode_alias = new_name(taken_select_names, "col") + + if is_posexplode: + pos_alias = new_name(taken_select_names, "pos") + + if not pos_alias: + pos_alias = new_name(taken_select_names, "pos") + + alias.set("alias", exp.to_identifier(explode_alias)) + + series_table_alias = series.args["alias"].this + column = exp.If( + this=exp.column(series_alias, table=series_table_alias).eq( + exp.column(pos_alias, table=unnest_source_alias) + ), + true=exp.column(explode_alias, table=unnest_source_alias), + ) + + explode.replace(column) + + if is_posexplode: + expressions = expression.expressions + expressions.insert( + expressions.index(alias) + 1, + exp.If( + this=exp.column( + series_alias, 
table=series_table_alias + ).eq(exp.column(pos_alias, table=unnest_source_alias)), + true=exp.column(pos_alias, table=unnest_source_alias), + ).as_(pos_alias), + ) + expression.set("expressions", expressions) + + if not arrays: + if expression.args.get("from_"): + expression.join(series, copy=False, join_type="CROSS") + else: + expression.from_(series, copy=False) + + size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) + arrays.append(size) + + # trino doesn't support left join unnest with on conditions + # if it did, this would be much simpler + expression.join( + exp.alias_( + exp.Unnest( + expressions=[explode_arg.copy()], + offset=exp.to_identifier(pos_alias), + ), + unnest_source_alias, + table=[explode_alias], + ), + join_type="CROSS", + copy=False, + ) + + if index_offset != 1: + size = size - 1 + + expression.where( + exp.column(series_alias, table=series_table_alias) + .eq(exp.column(pos_alias, table=unnest_source_alias)) + .or_( + ( + exp.column(series_alias, table=series_table_alias) + > size + ).and_( + exp.column(pos_alias, table=unnest_source_alias).eq( + size + ) + ) + ), + copy=False, + ) + + if arrays: + end: exp.Condition = exp.Greatest( + this=arrays[0], expressions=arrays[1:] + ) + + if index_offset != 1: + end = end - (1 - index_offset) + series.expressions[0].set("end", end) + + return expression + + return _explode_projection_to_unnest + + +def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: + """Transforms percentiles by adding a WITHIN GROUP clause to them.""" + if ( + isinstance(expression, exp.PERCENTILES) + and not isinstance(expression.parent, exp.WithinGroup) + and expression.expression + ): + column = expression.this.pop() + expression.set("this", expression.expression.pop()) + order = exp.Order(expressions=[exp.Ordered(this=column)]) + expression = exp.WithinGroup(this=expression, expression=order) + + return expression + + +def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: + """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" + if ( + isinstance(expression, exp.WithinGroup) + and isinstance(expression.this, exp.PERCENTILES) + and isinstance(expression.expression, exp.Order) + ): + quantile = expression.this.this + input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this + return expression.replace( + exp.ApproxQuantile(this=input_value, quantile=quantile) + ) + + return expression + + +def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: + """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" + if isinstance(expression, exp.With) and expression.recursive: + next_name = name_sequence("_c_") + + for cte in expression.expressions: + if not cte.args["alias"].columns: + query = cte.this + if isinstance(query, exp.SetOperation): + query = query.this + + cte.args["alias"].set( + "columns", + [ + exp.to_identifier(s.alias_or_name or next_name()) + for s in query.selects + ], + ) + + return expression + + +def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: + """Replace 'epoch' in casts by the equivalent date literal.""" + if ( + isinstance(expression, (exp.Cast, exp.TryCast)) + and expression.name.lower() == "epoch" + and expression.to.this in exp.DataType.TEMPORAL_TYPES + ): + expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) + + return expression + + +def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: + 
"""Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" + if isinstance(expression, exp.Select): + for join in expression.args.get("joins") or []: + on = join.args.get("on") + if on and join.kind in ("SEMI", "ANTI"): + subquery = exp.select("1").from_(join.this).where(on) + exists = exp.Exists(this=subquery) + if join.kind == "ANTI": + exists = exists.not_(copy=False) + + join.pop() + expression.where(exists, copy=False) + + return expression + + +def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: + """ + Converts a query with a FULL OUTER join to a union of identical queries that + use LEFT/RIGHT OUTER joins instead. This transformation currently only works + for queries that have a single FULL OUTER join. + """ + if isinstance(expression, exp.Select): + full_outer_joins = [ + (index, join) + for index, join in enumerate(expression.args.get("joins") or []) + if join.side == "FULL" + ] + + if len(full_outer_joins) == 1: + expression_copy = expression.copy() + expression.set("limit", None) + index, full_outer_join = full_outer_joins[0] + + tables = ( + expression.args["from_"].alias_or_name, + full_outer_join.alias_or_name, + ) + join_conditions = full_outer_join.args.get("on") or exp.and_( + *[ + exp.column(col, tables[0]).eq(exp.column(col, tables[1])) + for col in full_outer_join.args.get("using") + ] + ) + + full_outer_join.set("side", "left") + anti_join_clause = ( + exp.select("1").from_(expression.args["from_"]).where(join_conditions) + ) + expression_copy.args["joins"][index].set("side", "right") + expression_copy = expression_copy.where( + exp.Exists(this=anti_join_clause).not_() + ) + expression_copy.set("with_", None) # remove CTEs from RIGHT side + expression.set("order", None) # remove order by from LEFT side + + return exp.union(expression, expression_copy, copy=False, distinct=False) + + return expression + + +def move_ctes_to_top_level(expression: E) -> E: + """ + Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be + defined at the top-level, so for example queries like: + + SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq + + are invalid in those dialects. This transformation can be used to ensure all CTEs are + moved to the top level so that the final SQL code is valid from a syntax standpoint. + + TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 
+ """ + top_level_with = expression.args.get("with_") + for inner_with in expression.find_all(exp.With): + if inner_with.parent is expression: + continue + + if not top_level_with: + top_level_with = inner_with.pop() + expression.set("with_", top_level_with) + else: + if inner_with.recursive: + top_level_with.set("recursive", True) + + parent_cte = inner_with.find_ancestor(exp.CTE) + inner_with.pop() + + if parent_cte: + i = top_level_with.expressions.index(parent_cte) + top_level_with.expressions[i:i] = inner_with.expressions + top_level_with.set("expressions", top_level_with.expressions) + else: + top_level_with.set( + "expressions", top_level_with.expressions + inner_with.expressions + ) + + return expression + + +def ensure_bools(expression: exp.Expression) -> exp.Expression: + """Converts numeric values used in conditions into explicit boolean expressions.""" + from bigframes_vendored.sqlglot.optimizer.canonicalize import ensure_bools + + def _ensure_bool(node: exp.Expression) -> None: + if ( + node.is_number + or ( + not isinstance(node, exp.SubqueryPredicate) + and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) + ) + or (isinstance(node, exp.Column) and not node.type) + ): + node.replace(node.neq(0)) + + for node in expression.walk(): + ensure_bools(node, _ensure_bool) + + return expression + + +def unqualify_columns(expression: exp.Expression) -> exp.Expression: + for column in expression.find_all(exp.Column): + # We only wanna pop off the table, db, catalog args + for part in column.parts[:-1]: + part.pop() + + return expression + + +def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + for constraint in expression.find_all(exp.UniqueColumnConstraint): + if constraint.parent: + constraint.parent.pop() + + return expression + + +def ctas_with_tmp_tables_to_create_tmp_view( + expression: exp.Expression, + tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, +) -> exp.Expression: + assert isinstance(expression, exp.Create) + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + # CTAS with temp tables map to CREATE TEMPORARY VIEW + if expression.kind == "TABLE" and temporary: + if expression.expression: + return exp.Create( + kind="TEMPORARY VIEW", + this=expression.this, + expression=expression.expression, + ) + return tmp_storage_provider(expression) + + return expression + + +def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: + """ + In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the + PARTITIONED BY value is an array of column names, they are transformed into a schema. + The corresponding columns are removed from the create statement. 
+ """ + assert isinstance(expression, exp.Create) + has_schema = isinstance(expression.this, exp.Schema) + is_partitionable = expression.kind in {"TABLE", "VIEW"} + + if has_schema and is_partitionable: + prop = expression.find(exp.PartitionedByProperty) + if prop and prop.this and not isinstance(prop.this, exp.Schema): + schema = expression.this + columns = {v.name.upper() for v in prop.this.expressions} + partitions = [ + col for col in schema.expressions if col.name.upper() in columns + ] + schema.set( + "expressions", [e for e in schema.expressions if e not in partitions] + ) + prop.replace( + exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)) + ) + expression.set("this", schema) + + return expression + + +def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: + """ + Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. + + Currently, SQLGlot uses the DATASOURCE format for Spark 3. + """ + assert isinstance(expression, exp.Create) + prop = expression.find(exp.PartitionedByProperty) + if ( + prop + and prop.this + and isinstance(prop.this, exp.Schema) + and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) + ): + prop_this = exp.Tuple( + expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] + ) + schema = expression.this + for e in prop.this.expressions: + schema.append("expressions", e) + prop.set("this", prop_this) + + return expression + + +def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: + """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).""" + if isinstance(expression, exp.Struct): + expression.set( + "expressions", + [ + exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e + for e in expression.expressions + ], + ) + + return expression + + +def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: + """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178 + + 1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax. + + 2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view. + + The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query. + + You cannot use the (+) operator to outer-join a table to itself, although self joins are valid. + + The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator. + + A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator. + + A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression. + + A WHERE condition cannot compare any column marked with the (+) operator with a subquery. 
+ + -- example with WHERE + SELECT d.department_name, sum(e.salary) as total_salary + FROM departments d, employees e + WHERE e.department_id(+) = d.department_id + group by department_name + + -- example of left correlation in select + SELECT d.department_name, ( + SELECT SUM(e.salary) + FROM employees e + WHERE e.department_id(+) = d.department_id) AS total_salary + FROM departments d; + + -- example of left correlation in from + SELECT d.department_name, t.total_salary + FROM departments d, ( + SELECT SUM(e.salary) AS total_salary + FROM employees e + WHERE e.department_id(+) = d.department_id + ) t + """ + + from collections import defaultdict + + from bigframes_vendored.sqlglot.optimizer.normalize import normalize, normalized + from bigframes_vendored.sqlglot.optimizer.scope import traverse_scope + + # we go in reverse to check the main query for left correlation + for scope in reversed(traverse_scope(expression)): + query = scope.expression + + where = query.args.get("where") + joins = query.args.get("joins", []) + + if not where or not any( + c.args.get("join_mark") for c in where.find_all(exp.Column) + ): + continue + + # knockout: we do not support left correlation (see point 2) + assert not scope.is_correlated_subquery, "Correlated queries are not supported" + + # make sure we have AND of ORs to have clear join terms + where = normalize(where.this) + assert normalized(where), "Cannot normalize JOIN predicates" + + joins_ons = defaultdict(list) # dict of {name: list of join AND conditions} + for cond in [where] if not isinstance(where, exp.And) else where.flatten(): + join_cols = [ + col for col in cond.find_all(exp.Column) if col.args.get("join_mark") + ] + + left_join_table = set(col.table for col in join_cols) + if not left_join_table: + continue + + assert not ( + len(left_join_table) > 1 + ), "Cannot combine JOIN predicates from different tables" + + for col in join_cols: + col.set("join_mark", False) + + joins_ons[left_join_table.pop()].append(cond) + + old_joins = {join.alias_or_name: join for join in joins} + new_joins = {} + query_from = query.args["from_"] + + for table, predicates in joins_ons.items(): + join_what = old_joins.get(table, query_from).this.copy() + new_joins[join_what.alias_or_name] = exp.Join( + this=join_what, on=exp.and_(*predicates), kind="LEFT" + ) + + for p in predicates: + while isinstance(p.parent, exp.Paren): + p.parent.replace(p) + + parent = p.parent + p.pop() + if isinstance(parent, exp.Binary): + parent.replace(parent.right if parent.left is None else parent.left) + elif isinstance(parent, exp.Where): + parent.pop() + + if query_from.alias_or_name in new_joins: + only_old_joins = old_joins.keys() - new_joins.keys() + assert ( + len(only_old_joins) >= 1 + ), "Cannot determine which table to use in the new FROM clause" + + new_from_name = list(only_old_joins)[0] + query.set("from_", exp.From(this=old_joins[new_from_name].this)) + + if new_joins: + for n, j in old_joins.items(): # preserve any other joins + if n not in new_joins and n != query.args["from_"].name: + if not j.kind: + j.set("kind", "CROSS") + new_joins[n] = j + query.set("joins", list(new_joins.values())) + + return expression + + +def any_to_exists(expression: exp.Expression) -> exp.Expression: + """ + Transform ANY operator to Spark's EXISTS + + For example, + - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) + - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5) + + Both ANY and EXISTS accept queries but currently only array expressions are supported for this + 
transformation + """ + if isinstance(expression, exp.Select): + for any_expr in expression.find_all(exp.Any): + this = any_expr.this + if isinstance(this, exp.Query) or isinstance( + any_expr.parent, (exp.Like, exp.ILike) + ): + continue + + binop = any_expr.parent + if isinstance(binop, exp.Binary): + lambda_arg = exp.to_identifier("x") + any_expr.replace(lambda_arg) + lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg]) + binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr)) + + return expression + + +def eliminate_window_clause(expression: exp.Expression) -> exp.Expression: + """Eliminates the `WINDOW` query clause by inlining each named window.""" + if isinstance(expression, exp.Select) and expression.args.get("windows"): + from bigframes_vendored.sqlglot.optimizer.scope import find_all_in_scope + + windows = expression.args["windows"] + expression.set("windows", None) + + window_expression: t.Dict[str, exp.Expression] = {} + + def _inline_inherited_window(window: exp.Expression) -> None: + inherited_window = window_expression.get(window.alias.lower()) + if not inherited_window: + return + + window.set("alias", None) + for key in ("partition_by", "order", "spec"): + arg = inherited_window.args.get(key) + if arg: + window.set(key, arg.copy()) + + for window in windows: + _inline_inherited_window(window) + window_expression[window.name.lower()] = window + + for window in find_all_in_scope(expression, exp.Window): + _inline_inherited_window(window) + + return expression + + +def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression: + """ + Inherit field names from the first struct in an array. + + BigQuery supports implicitly inheriting names from the first STRUCT in an array: + + Example: + ARRAY[ + STRUCT('Alice' AS name, 85 AS score), -- defines names + STRUCT('Bob', 92), -- inherits names + STRUCT('Diana', 95) -- inherits names + ] + + This transformation makes the field names explicit on all structs by adding + PropertyEQ nodes, in order to facilitate transpilation to other dialects.
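As a concrete illustration (see also the Example above), a minimal sketch of building the explicit PropertyEQ form by hand for STRUCT('Bob', 92); it assumes the vendored package re-exports upstream sqlglot's expression helpers (convert, to_identifier), and the printed SQL is an expectation, not captured output:

import bigframes_vendored.sqlglot.expressions as sge

# The first struct in the array literal defines the field names.
named = sge.Struct(
    expressions=[
        sge.PropertyEQ(this=sge.to_identifier("name"), expression=sge.convert("Alice")),
        sge.PropertyEQ(this=sge.to_identifier("score"), expression=sge.convert(85)),
    ]
)
# A later struct omits the names, as BigQuery allows.
unnamed = sge.Struct(expressions=[sge.convert("Bob"), sge.convert(92)])

# Mirror the transform: wrap each unnamed value in a PropertyEQ that reuses
# the corresponding field name from the first struct.
unnamed.set(
    "expressions",
    [
        sge.PropertyEQ(this=fld.this.copy(), expression=value)
        for fld, value in zip(named.expressions, unnamed.expressions)
    ],
)
print(unnamed.sql(dialect="bigquery"))  # expected: STRUCT('Bob' AS name, 92 AS score)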
+ + Args: + expression: The expression tree to transform + + Returns: + The modified expression with field names inherited in all structs + """ + if ( + isinstance(expression, exp.Array) + and expression.args.get("struct_name_inheritance") + and isinstance(first_item := seq_get(expression.expressions, 0), exp.Struct) + and all(isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions) + ): + field_names = [fld.this for fld in first_item.expressions] + + # Apply field names to subsequent structs that don't have them + for struct in expression.expressions[1:]: + if not isinstance(struct, exp.Struct) or len(struct.expressions) != len( + field_names + ): + continue + + # Convert unnamed expressions to PropertyEQ with inherited names + new_expressions = [] + for i, expr in enumerate(struct.expressions): + if not isinstance(expr, exp.PropertyEQ): + # Create PropertyEQ: field_name := value + new_expressions.append( + exp.PropertyEQ( + this=exp.Identifier(this=field_names[i].copy()), + expression=expr, + ) + ) + else: + new_expressions.append(expr) + + struct.set("expressions", new_expressions) + + return expression diff --git a/third_party/bigframes_vendored/sqlglot/trie.py b/third_party/bigframes_vendored/sqlglot/trie.py new file mode 100644 index 0000000000..bbdefbb0cd --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/trie.py @@ -0,0 +1,81 @@ +from enum import auto, Enum +import typing as t + +key = t.Sequence[t.Hashable] + + +class TrieResult(Enum): + FAILED = auto() + PREFIX = auto() + EXISTS = auto() + + +def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict: + """ + Creates a new trie out of a collection of keywords. + + The trie is represented as a sequence of nested dictionaries keyed by either single + character strings, or by 0, which is used to designate that a keyword is in the trie. + + Example: + >>> new_trie(["bla", "foo", "blab"]) + {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}} + + Args: + keywords: the keywords to create the trie from. + trie: a trie to mutate instead of creating a new one + + Returns: + The trie corresponding to `keywords`. + """ + trie = {} if trie is None else trie + + for key in keywords: + current = trie + for char in key: + current = current.setdefault(char, {}) + + current[0] = True + + return trie + + +def in_trie(trie: t.Dict, key: key) -> t.Tuple[TrieResult, t.Dict]: + """ + Checks whether a key is in a trie. + + Examples: + >>> in_trie(new_trie(["cat"]), "bob") + (<TrieResult.FAILED: 1>, {'c': {'a': {'t': {0: True}}}}) + + >>> in_trie(new_trie(["cat"]), "ca") + (<TrieResult.PREFIX: 2>, {'t': {0: True}}) + + >>> in_trie(new_trie(["cat"]), "cat") + (<TrieResult.EXISTS: 3>, {0: True}) + + Args: + trie: The trie to be searched. + key: The target key.
+ + Returns: + A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point + where the search stops, and `value` is a TrieResult value that can be one of: + + - TrieResult.FAILED: the search was unsuccessful + - TrieResult.PREFIX: `key` is a prefix of a keyword in `trie` + - TrieResult.EXISTS: `key` exists in `trie` + """ + if not key: + return (TrieResult.FAILED, trie) + + current = trie + for char in key: + if char not in current: + return (TrieResult.FAILED, current) + current = current[char] + + if 0 in current: + return (TrieResult.EXISTS, current) + + return (TrieResult.PREFIX, current) diff --git a/third_party/bigframes_vendored/sqlglot/typing/__init__.py b/third_party/bigframes_vendored/sqlglot/typing/__init__.py new file mode 100644 index 0000000000..3bc44218f7 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/__init__.py @@ -0,0 +1,358 @@ +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.helper import subclasses + +ExpressionMetadataType = t.Dict[type[exp.Expression], t.Dict[str, t.Any]] + +TIMESTAMP_EXPRESSIONS = { + exp.CurrentTimestamp, + exp.StrToTime, + exp.TimeStrToTime, + exp.TimestampAdd, + exp.TimestampSub, + exp.UnixToTime, +} + +EXPRESSION_METADATA: ExpressionMetadataType = { + **{ + expr_type: {"annotator": lambda self, e: self._annotate_binary(e)} + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_unary(e)} + for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) + }, + **{ + expr_type: {"returns": exp.DataType.Type.BIGINT} + for expr_type in { + exp.ApproxDistinct, + exp.ArraySize, + exp.CountIf, + exp.Int64, + exp.Length, + exp.UnixDate, + exp.UnixSeconds, + exp.UnixMicros, + exp.UnixMillis, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BINARY} + for expr_type in { + exp.FromBase32, + exp.FromBase64, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BOOLEAN} + for expr_type in { + exp.All, + exp.Any, + exp.Between, + exp.Boolean, + exp.Contains, + exp.EndsWith, + exp.Exists, + exp.In, + exp.LogicalAnd, + exp.LogicalOr, + exp.RegexpLike, + exp.StartsWith, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DATE} + for expr_type in { + exp.CurrentDate, + exp.Date, + exp.DateFromParts, + exp.DateStrToDate, + exp.DiToDate, + exp.LastDay, + exp.StrToDate, + exp.TimeStrToDate, + exp.TsOrDsToDate, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DATETIME} + for expr_type in { + exp.CurrentDatetime, + exp.Datetime, + exp.DatetimeAdd, + exp.DatetimeSub, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DOUBLE} + for expr_type in { + exp.ApproxQuantile, + exp.Avg, + exp.Exp, + exp.Ln, + exp.Log, + exp.Pi, + exp.Pow, + exp.Quantile, + exp.Radians, + exp.Round, + exp.SafeDivide, + exp.Sqrt, + exp.Stddev, + exp.StddevPop, + exp.StddevSamp, + exp.ToDouble, + exp.Variance, + exp.VariancePop, + exp.Skewness, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.INT} + for expr_type in { + exp.Ascii, + exp.Ceil, + exp.DatetimeDiff, + exp.TimestampDiff, + exp.TimeDiff, + exp.Unicode, + exp.DateToDi, + exp.Levenshtein, + exp.Sign, + exp.StrPosition, + exp.TsOrDiToDi, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.INTERVAL} + for expr_type in { + exp.Interval, + exp.JustifyDays, + exp.JustifyHours, + exp.JustifyInterval, + exp.MakeInterval, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.JSON} + for expr_type in { + exp.ParseJSON, + } + }, + **{
expr_type: {"returns": exp.DataType.Type.TIME} + for expr_type in { + exp.CurrentTime, + exp.Time, + exp.TimeAdd, + exp.TimeSub, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.TIMESTAMPLTZ} + for expr_type in { + exp.TimestampLtzFromParts, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.TIMESTAMPTZ} + for expr_type in { + exp.CurrentTimestampLTZ, + exp.TimestampTzFromParts, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.TIMESTAMP} + for expr_type in TIMESTAMP_EXPRESSIONS + }, + **{ + expr_type: {"returns": exp.DataType.Type.TINYINT} + for expr_type in { + exp.Day, + exp.DayOfMonth, + exp.DayOfWeek, + exp.DayOfWeekIso, + exp.DayOfYear, + exp.Month, + exp.Quarter, + exp.Week, + exp.WeekOfYear, + exp.Year, + exp.YearOfWeek, + exp.YearOfWeekIso, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.VARCHAR} + for expr_type in { + exp.ArrayToString, + exp.Concat, + exp.ConcatWs, + exp.Chr, + exp.DateToDateStr, + exp.DPipe, + exp.GroupConcat, + exp.Initcap, + exp.Lower, + exp.Substring, + exp.String, + exp.TimeToStr, + exp.TimeToTimeStr, + exp.Trim, + exp.ToBase32, + exp.ToBase64, + exp.TsOrDsToDateStr, + exp.UnixToStr, + exp.UnixToTimeStr, + exp.Upper, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this")} + for expr_type in { + exp.Abs, + exp.AnyValue, + exp.ArrayConcatAgg, + exp.ArrayReverse, + exp.ArraySlice, + exp.Filter, + exp.HavingMax, + exp.LastValue, + exp.Limit, + exp.Order, + exp.SortArray, + exp.Window, + } + }, + **{ + expr_type: { + "annotator": lambda self, e: self._annotate_by_args( + e, "this", "expressions" + ) + } + for expr_type in { + exp.ArrayConcat, + exp.Coalesce, + exp.Greatest, + exp.Least, + exp.Max, + exp.Min, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_by_array_element(e)} + for expr_type in { + exp.ArrayFirst, + exp.ArrayLast, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.UNKNOWN} + for expr_type in { + exp.Anonymous, + exp.Slice, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_timeunit(e)} + for expr_type in { + exp.DateAdd, + exp.DateSub, + exp.DateTrunc, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._set_type(e, e.args["to"])} + for expr_type in { + exp.Cast, + exp.TryCast, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_map(e)} + for expr_type in { + exp.Map, + exp.VarMap, + } + }, + exp.Array: { + "annotator": lambda self, e: self._annotate_by_args( + e, "expressions", array=True + ) + }, + exp.ArrayAgg: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", array=True) + }, + exp.Bracket: {"annotator": lambda self, e: self._annotate_bracket(e)}, + exp.Case: { + "annotator": lambda self, e: self._annotate_by_args( + e, *[if_expr.args["true"] for if_expr in e.args["ifs"]], "default" + ) + }, + exp.Count: { + "annotator": lambda self, e: self._set_type( + e, + exp.DataType.Type.BIGINT + if e.args.get("big_int") + else exp.DataType.Type.INT, + ) + }, + exp.DateDiff: { + "annotator": lambda self, e: self._set_type( + e, + exp.DataType.Type.BIGINT + if e.args.get("big_int") + else exp.DataType.Type.INT, + ) + }, + exp.DataType: {"annotator": lambda self, e: self._set_type(e, e.copy())}, + exp.Div: {"annotator": lambda self, e: self._annotate_div(e)}, + exp.Distinct: { + "annotator": lambda self, e: self._annotate_by_args(e, "expressions") + }, + exp.Dot: {"annotator": lambda self, e: self._annotate_dot(e)}, + exp.Explode: {"annotator": lambda self, e: 
self._annotate_explode(e)}, + exp.Extract: {"annotator": lambda self, e: self._annotate_extract(e)}, + exp.GenerateSeries: { + "annotator": lambda self, e: self._annotate_by_args( + e, "start", "end", "step", array=True + ) + }, + exp.GenerateDateArray: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<DATE>") + ) + }, + exp.GenerateTimestampArray: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<TIMESTAMP>") + ) + }, + exp.If: {"annotator": lambda self, e: self._annotate_by_args(e, "true", "false")}, + exp.Literal: {"annotator": lambda self, e: self._annotate_literal(e)}, + exp.Null: {"returns": exp.DataType.Type.NULL}, + exp.Nullif: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", "expression") + }, + exp.PropertyEQ: { + "annotator": lambda self, e: self._annotate_by_args(e, "expression") + }, + exp.Struct: {"annotator": lambda self, e: self._annotate_struct(e)}, + exp.Sum: { + "annotator": lambda self, e: self._annotate_by_args( + e, "this", "expressions", promote=True + ) + }, + exp.Timestamp: { + "annotator": lambda self, e: self._set_type( + e, + exp.DataType.Type.TIMESTAMPTZ + if e.args.get("with_tz") + else exp.DataType.Type.TIMESTAMP, + ) + }, + exp.ToMap: {"annotator": lambda self, e: self._annotate_to_map(e)}, + exp.Unnest: {"annotator": lambda self, e: self._annotate_unnest(e)}, + exp.Subquery: {"annotator": lambda self, e: self._annotate_subquery(e)}, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/bigquery.py b/third_party/bigframes_vendored/sqlglot/typing/bigquery.py new file mode 100644 index 0000000000..f3bff12bdd --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/bigquery.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA, TIMESTAMP_EXPRESSIONS + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator + + +def _annotate_math_functions( + self: TypeAnnotator, expression: exp.Expression +) -> exp.Expression: + """ + Many BigQuery math functions such as CEIL, FLOOR etc follow this return type convention: + +---------+---------+---------+------------+---------+ + | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + +---------+---------+---------+------------+---------+ + | OUTPUT | FLOAT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + +---------+---------+---------+------------+---------+ + """ + this: exp.Expression = expression.this + + self._set_type( + expression, + exp.DataType.Type.DOUBLE + if this.is_type(*exp.DataType.INTEGER_TYPES) + else this.type, + ) + return expression + + +def _annotate_safe_divide( + self: TypeAnnotator, expression: exp.SafeDivide +) -> exp.Expression: + """ + +------------+------------+------------+-------------+---------+ + | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + +------------+------------+------------+-------------+---------+ + | INT64 | FLOAT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + | NUMERIC | NUMERIC | NUMERIC | BIGNUMERIC | FLOAT64 | + | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | FLOAT64 | + | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | + +------------+------------+------------+-------------+---------+ + """ + if expression.this.is_type( + *exp.DataType.INTEGER_TYPES + ) and expression.expression.is_type(*exp.DataType.INTEGER_TYPES): + return self._set_type(expression, exp.DataType.Type.DOUBLE) + + return _annotate_by_args_with_coerce(self,
expression) + + +def _annotate_by_args_with_coerce( + self: TypeAnnotator, expression: exp.Expression +) -> exp.Expression: + """ + +------------+------------+------------+-------------+---------+ + | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + +------------+------------+------------+-------------+---------+ + | INT64 | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | + | NUMERIC | NUMERIC | NUMERIC | BIGNUMERIC | FLOAT64 | + | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | FLOAT64 | + | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | + +------------+------------+------------+-------------+---------+ + """ + self._set_type( + expression, self._maybe_coerce(expression.this.type, expression.expression.type) + ) + return expression + + +def _annotate_by_args_approx_top( + self: TypeAnnotator, expression: exp.ApproxTopK +) -> exp.ApproxTopK: + struct_type = exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[expression.this.type, exp.DataType(this=exp.DataType.Type.BIGINT)], + nested=True, + ) + self._set_type( + expression, + exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[struct_type], nested=True + ), + ) + + return expression + + +def _annotate_concat(self: TypeAnnotator, expression: exp.Concat) -> exp.Concat: + annotated = self._annotate_by_args(expression, "expressions") + + # Args must be BYTES or types that can be cast to STRING, return type is either BYTES or STRING + # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#concat + if not annotated.is_type(exp.DataType.Type.BINARY, exp.DataType.Type.UNKNOWN): + self._set_type(annotated, exp.DataType.Type.VARCHAR) + + return annotated + + +def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array: + array_args = expression.expressions + + # BigQuery behaves as follows: + # + # SELECT t, TYPEOF(t) FROM (SELECT 'foo') AS t -- foo, STRUCT<STRING> + # SELECT ARRAY(SELECT 'foo'), TYPEOF(ARRAY(SELECT 'foo')) -- foo, ARRAY<STRING> + # ARRAY(SELECT ... UNION ALL SELECT ...) -- ARRAY<type of the first projection> + if len(array_args) == 1: + unnested = array_args[0].unnest() + projection_type: t.Optional[exp.DataType | exp.DataType.Type] = None + + # Handle ARRAY(SELECT ...) - single SELECT query + if isinstance(unnested, exp.Select): + if ( + (query_type := unnested.meta.get("query_type")) is not None + and query_type.is_type(exp.DataType.Type.STRUCT) + and len(query_type.expressions) == 1 + and isinstance(col_def := query_type.expressions[0], exp.ColumnDef) + and (col_type := col_def.kind) is not None + and not col_type.is_type(exp.DataType.Type.UNKNOWN) + ): + projection_type = col_type + + # Handle ARRAY(SELECT ... UNION ALL SELECT ...)
- set operations + elif isinstance(unnested, exp.SetOperation): + # Get all column types for the SetOperation + col_types = self._get_setop_column_types(unnested) + # For ARRAY constructor, there should only be one projection + # https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#array + if col_types and unnested.left.selects: + first_col_name = unnested.left.selects[0].alias_or_name + projection_type = col_types.get(first_col_name) + + # If we successfully determine a projection type and it's not UNKNOWN, wrap it in ARRAY + if projection_type and not ( + ( + isinstance(projection_type, exp.DataType) + and projection_type.is_type(exp.DataType.Type.UNKNOWN) + ) + or projection_type == exp.DataType.Type.UNKNOWN + ): + element_type = ( + projection_type.copy() + if isinstance(projection_type, exp.DataType) + else exp.DataType(this=projection_type) + ) + array_type = exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[element_type], + nested=True, + ) + return self._set_type(expression, array_type) + + return self._annotate_by_args(expression, "expressions", array=True) + + +EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + **{ + expr_type: {"annotator": lambda self, e: _annotate_math_functions(self, e)} + for expr_type in { + exp.Avg, + exp.Ceil, + exp.Exp, + exp.Floor, + exp.Ln, + exp.Log, + exp.Round, + exp.Sqrt, + } + }, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this")} + for expr_type in { + exp.Abs, + exp.ArgMax, + exp.ArgMin, + exp.DateTrunc, + exp.DatetimeTrunc, + exp.FirstValue, + exp.GroupConcat, + exp.IgnoreNulls, + exp.JSONExtract, + exp.Lead, + exp.Left, + exp.Lower, + exp.NthValue, + exp.Pad, + exp.PercentileDisc, + exp.RegexpExtract, + exp.RegexpReplace, + exp.Repeat, + exp.Replace, + exp.RespectNulls, + exp.Reverse, + exp.Right, + exp.SafeNegate, + exp.Sign, + exp.Substring, + exp.TimestampTrunc, + exp.Translate, + exp.Trim, + exp.Upper, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BIGINT} + for expr_type in { + exp.Ascii, + exp.BitwiseAndAgg, + exp.BitwiseCount, + exp.BitwiseOrAgg, + exp.BitwiseXorAgg, + exp.ByteLength, + exp.DenseRank, + exp.FarmFingerprint, + exp.Grouping, + exp.LaxInt64, + exp.Length, + exp.Ntile, + exp.Rank, + exp.RangeBucket, + exp.RegexpInstr, + exp.RowNumber, + exp.Unicode, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BINARY} + for expr_type in { + exp.ByteString, + exp.CodePointsToBytes, + exp.MD5Digest, + exp.SHA, + exp.SHA2, + exp.SHA1Digest, + exp.SHA2Digest, + exp.Unhex, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BOOLEAN} + for expr_type in { + exp.IsInf, + exp.IsNan, + exp.JSONBool, + exp.LaxBool, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DATETIME} + for expr_type in { + exp.ParseDatetime, + exp.TimestampFromParts, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DOUBLE} + for expr_type in { + exp.Acos, + exp.Acosh, + exp.Asin, + exp.Asinh, + exp.Atan, + exp.Atan2, + exp.Atanh, + exp.Cbrt, + exp.Corr, + exp.CosineDistance, + exp.Cot, + exp.Coth, + exp.CovarPop, + exp.CovarSamp, + exp.Csc, + exp.Csch, + exp.CumeDist, + exp.EuclideanDistance, + exp.Float64, + exp.LaxFloat64, + exp.PercentRank, + exp.Rand, + exp.Sec, + exp.Sech, + exp.Sin, + exp.Sinh, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.JSON} + for expr_type in { + exp.JSONArray, + exp.JSONArrayAppend, + exp.JSONArrayInsert, + exp.JSONObject, + exp.JSONRemove, + exp.JSONSet, + exp.JSONStripNulls, + } + }, + **{ + expr_type: {"returns": 
exp.DataType.Type.TIME} + for expr_type in { + exp.ParseTime, + exp.TimeFromParts, + exp.TimeTrunc, + exp.TsOrDsToTime, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.VARCHAR} + for expr_type in { + exp.CodePointsToString, + exp.Format, + exp.JSONExtractScalar, + exp.JSONType, + exp.LaxString, + exp.LowerHex, + exp.MD5, + exp.NetHost, + exp.Normalize, + exp.SafeConvertBytesToString, + exp.Soundex, + exp.Uuid, + } + }, + **{ + expr_type: {"annotator": lambda self, e: _annotate_by_args_with_coerce(self, e)} + for expr_type in { + exp.PercentileCont, + exp.SafeAdd, + exp.SafeDivide, + exp.SafeMultiply, + exp.SafeSubtract, + } + }, + **{ + expr_type: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", array=True) + } + for expr_type in { + exp.ApproxQuantiles, + exp.JSONExtractArray, + exp.RegexpExtractAll, + exp.Split, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.TIMESTAMPTZ} + for expr_type in TIMESTAMP_EXPRESSIONS + }, + exp.ApproxTopK: { + "annotator": lambda self, e: _annotate_by_args_approx_top(self, e) + }, + exp.ApproxTopSum: { + "annotator": lambda self, e: _annotate_by_args_approx_top(self, e) + }, + exp.Array: {"annotator": _annotate_array}, + exp.ArrayConcat: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", "expressions") + }, + exp.Concat: {"annotator": _annotate_concat}, + exp.DateFromUnixDate: {"returns": exp.DataType.Type.DATE}, + exp.GenerateTimestampArray: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<TIMESTAMP>", dialect="bigquery") + ) + }, + exp.JSONFormat: { + "annotator": lambda self, e: self._set_type( + e, + exp.DataType.Type.JSON + if e.args.get("to_json") + else exp.DataType.Type.VARCHAR, + ) + }, + exp.JSONKeysAtDepth: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<STRING>", dialect="bigquery") + ) + }, + exp.JSONValueArray: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<STRING>", dialect="bigquery") + ) + }, + exp.Lag: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", "default") + }, + exp.ParseBignumeric: {"returns": exp.DataType.Type.BIGDECIMAL}, + exp.ParseNumeric: {"returns": exp.DataType.Type.DECIMAL}, + exp.SafeDivide: {"annotator": lambda self, e: _annotate_safe_divide(self, e)}, + exp.ToCodePoints: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("ARRAY<INT64>", dialect="bigquery") + ) + }, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/hive.py b/third_party/bigframes_vendored/sqlglot/typing/hive.py new file mode 100644 index 0000000000..1cc24c670c --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/hive.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA + +EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + exp.If: { + "annotator": lambda self, e: self._annotate_by_args( + e, "true", "false", promote=True + ) + }, + exp.Coalesce: { + "annotator": lambda self, e: self._annotate_by_args( + e, "this", "expressions", promote=True + ) + }, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/presto.py b/third_party/bigframes_vendored/sqlglot/typing/presto.py new file mode 100644 index 0000000000..41bdb167b6 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/presto.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA +
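Every dialect module in this vendored typing package, including the Presto table that follows, is built the same way: spread the shared EXPRESSION_METADATA and override only the entries whose typing rules differ. A minimal sketch of what that implies for consumers, using only modules introduced in this diff (the assertions are illustrative, not vendored tests):

from bigframes_vendored.sqlglot import exp
import bigframes_vendored.sqlglot.typing as base_typing
import bigframes_vendored.sqlglot.typing.presto as presto_typing

# Entries a dialect does not override are shared object-for-object with the base table.
assert (
    base_typing.EXPRESSION_METADATA[exp.Length]
    is presto_typing.EXPRESSION_METADATA[exp.Length]
)
# Overridden entries (e.g. exp.Abs just below) carry dialect-specific annotators instead.
assert (
    base_typing.EXPRESSION_METADATA[exp.Abs]
    is not presto_typing.EXPRESSION_METADATA[exp.Abs]
)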
+EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this")} + for expr_type in { + exp.Abs, + exp.Ceil, + exp.Floor, + exp.Round, + exp.Sign, + } + }, + exp.Mod: { + "annotator": lambda self, e: self._annotate_by_args(e, "this", "expression") + }, + exp.Rand: { + "annotator": lambda self, e: self._annotate_by_args(e, "this") + if e.this + else self._set_type(e, exp.DataType.Type.DOUBLE) + }, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/snowflake.py b/third_party/bigframes_vendored/sqlglot/typing/snowflake.py new file mode 100644 index 0000000000..934910e2af --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/snowflake.py @@ -0,0 +1,545 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.helper import seq_get +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator + +DATE_PARTS = {"DAY", "WEEK", "MONTH", "QUARTER", "YEAR"} + +MAX_PRECISION = 38 + +MAX_SCALE = 37 + + +def _annotate_reverse(self: TypeAnnotator, expression: exp.Reverse) -> exp.Reverse: + expression = self._annotate_by_args(expression, "this") + if expression.is_type(exp.DataType.Type.NULL): + # Snowflake treats REVERSE(NULL) as a VARCHAR + self._set_type(expression, exp.DataType.Type.VARCHAR) + + return expression + + +def _annotate_timestamp_from_parts( + self: TypeAnnotator, expression: exp.TimestampFromParts +) -> exp.TimestampFromParts: + """Annotate TimestampFromParts with correct type based on arguments. + TIMESTAMP_FROM_PARTS with time_zone -> TIMESTAMPTZ + TIMESTAMP_FROM_PARTS without time_zone -> TIMESTAMP (defaults to TIMESTAMP_NTZ) + """ + if expression.args.get("zone"): + self._set_type(expression, exp.DataType.Type.TIMESTAMPTZ) + else: + self._set_type(expression, exp.DataType.Type.TIMESTAMP) + + return expression + + +def _annotate_date_or_time_add( + self: TypeAnnotator, expression: exp.Expression +) -> exp.Expression: + if ( + expression.this.is_type(exp.DataType.Type.DATE) + and expression.text("unit").upper() not in DATE_PARTS + ): + self._set_type(expression, exp.DataType.Type.TIMESTAMPNTZ) + else: + self._annotate_by_args(expression, "this") + return expression + + +def _annotate_decode_case( + self: TypeAnnotator, expression: exp.DecodeCase +) -> exp.DecodeCase: + """Annotate DecodeCase with the type inferred from return values only. + + DECODE uses the format: DECODE(expr, val1, ret1, val2, ret2, ..., default) + We only look at the return values (ret1, ret2, ..., default) to determine the type, + not the comparison values (val1, val2, ...) or the expression being compared. + """ + expressions = expression.expressions + + # Return values are at indices 2, 4, 6, ... 
and the last element (if even length) + # DECODE(expr, val1, ret1, val2, ret2, ..., default) + return_types = [expressions[i].type for i in range(2, len(expressions), 2)] + + # If the total number of expressions is even, the last one is the default + # Example: + # DECODE(x, 1, 'a', 2, 'b') -> len=5 (odd), no default + # DECODE(x, 1, 'a', 2, 'b', 'default') -> len=6 (even), has default + if len(expressions) % 2 == 0: + return_types.append(expressions[-1].type) + + # Determine the common type from all return values + last_type = None + for ret_type in return_types: + last_type = self._maybe_coerce(last_type or ret_type, ret_type) + + self._set_type(expression, last_type) + return expression + + +def _annotate_arg_max_min(self, expression): + self._set_type( + expression, + exp.DataType.Type.ARRAY + if expression.args.get("count") + else expression.this.type, + ) + return expression + + +def _annotate_within_group( + self: TypeAnnotator, expression: exp.WithinGroup +) -> exp.WithinGroup: + """Annotate WithinGroup with correct type based on the inner function. + + 1) Annotate args first + 2) Check if this is PercentileDisc/PercentileCont and if so, re-annotate its type to match the ordered expression's type + """ + + if ( + isinstance(expression.this, (exp.PercentileDisc, exp.PercentileCont)) + and isinstance(order_expr := expression.expression, exp.Order) + and len(order_expr.expressions) == 1 + and isinstance(ordered_expr := order_expr.expressions[0], exp.Ordered) + ): + self._set_type(expression, ordered_expr.this.type) + + return expression + + +def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median: + """Annotate MEDIAN function with correct return type. + + Based on Snowflake documentation: + - If the expr is FLOAT/DOUBLE -> annotate as DOUBLE (FLOAT is a synonym for DOUBLE) + - If the expr is NUMBER(p, s) -> annotate as NUMBER(min(p+3, 38), min(s+3, 37)) + """ + # First annotate the argument to get its type + expression = self._annotate_by_args(expression, "this") + + # Get the input type + input_type = expression.this.type + + if input_type.is_type(exp.DataType.Type.DOUBLE): + # If input is FLOAT/DOUBLE, return DOUBLE (FLOAT is normalized to DOUBLE in Snowflake) + self._set_type(expression, exp.DataType.Type.DOUBLE) + else: + # If input is NUMBER(p, s), return NUMBER(min(p+3, 38), min(s+3, 37)) + exprs = input_type.expressions + + precision_expr = seq_get(exprs, 0) + precision = precision_expr.this.to_py() if precision_expr else MAX_PRECISION + + scale_expr = seq_get(exprs, 1) + scale = scale_expr.this.to_py() if scale_expr else 0 + + new_precision = min(precision + 3, MAX_PRECISION) + new_scale = min(scale + 3, MAX_SCALE) + + # Build the new NUMBER type + new_type = exp.DataType.build( + f"NUMBER({new_precision}, {new_scale})", dialect="snowflake" + ) + self._set_type(expression, new_type) + + return expression + + +def _annotate_variance( + self: TypeAnnotator, expression: exp.Expression +) -> exp.Expression: + """Annotate variance functions (VAR_POP, VAR_SAMP, VARIANCE, VARIANCE_POP) with correct return type. 
+ + Based on Snowflake behavior: + - DECFLOAT -> DECFLOAT(38) + - FLOAT/DOUBLE -> FLOAT + - INT, NUMBER(p, 0) -> NUMBER(38, 6) + - NUMBER(p, s) -> NUMBER(38, max(12, s)) + """ + # First annotate the argument to get its type + expression = self._annotate_by_args(expression, "this") + + # Get the input type + input_type = expression.this.type + + # Special case: DECFLOAT -> DECFLOAT(38) + if input_type.is_type(exp.DataType.Type.DECFLOAT): + self._set_type(expression, exp.DataType.build("DECFLOAT", dialect="snowflake")) + # Special case: FLOAT/DOUBLE -> DOUBLE + elif input_type.is_type(exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE): + self._set_type(expression, exp.DataType.Type.DOUBLE) + # For NUMBER types: determine the scale + else: + exprs = input_type.expressions + scale_expr = seq_get(exprs, 1) + scale = scale_expr.this.to_py() if scale_expr else 0 + + # If scale is 0 (INT, BIGINT, NUMBER(p,0)): return NUMBER(38, 6) + # Otherwise, Snowflake appears to assign scale through the formula MAX(12, s) + new_scale = 6 if scale == 0 else max(12, scale) + + # Build the new NUMBER type + new_type = exp.DataType.build( + f"NUMBER({MAX_PRECISION}, {new_scale})", dialect="snowflake" + ) + self._set_type(expression, new_type) + + return expression + + +def _annotate_math_with_float_decfloat( + self: TypeAnnotator, expression: exp.Expression +) -> exp.Expression: + """Annotate math functions that preserve DECFLOAT but return DOUBLE for others. + + In Snowflake, trigonometric and exponential math functions: + - If input is DECFLOAT -> return DECFLOAT + - For integer types (INT, BIGINT, etc.) -> return DOUBLE + - For other numeric types (NUMBER, DECIMAL, DOUBLE) -> return DOUBLE + """ + expression = self._annotate_by_args(expression, "this") + + # If input is DECFLOAT, preserve + if expression.this.is_type(exp.DataType.Type.DECFLOAT): + self._set_type(expression, expression.this.type) + else: + # For all other types (integers, decimals, etc.), return DOUBLE + self._set_type(expression, exp.DataType.Type.DOUBLE) + + return expression + + +def _annotate_str_to_time( + self: TypeAnnotator, expression: exp.StrToTime +) -> exp.StrToTime: + # target_type is stored as a DataType instance + target_type_arg = expression.args.get("target_type") + target_type = ( + target_type_arg.this + if isinstance(target_type_arg, exp.DataType) + else exp.DataType.Type.TIMESTAMP + ) + self._set_type(expression, target_type) + return expression + + +EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + **{ + expr_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this")} + for expr_type in { + exp.AddMonths, + exp.Ceil, + exp.DateTrunc, + exp.Floor, + exp.Left, + exp.Mode, + exp.Pad, + exp.Right, + exp.Round, + exp.Stuff, + exp.Substring, + exp.TimeSlice, + exp.TimestampTrunc, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.ARRAY} + for expr_type in ( + exp.ApproxTopK, + exp.ApproxTopKEstimate, + exp.Array, + exp.ArrayAgg, + exp.ArrayConstructCompact, + exp.ArrayUniqueAgg, + exp.ArrayUnionAgg, + exp.MapKeys, + exp.RegexpExtractAll, + exp.Split, + exp.StringToArray, + ) + }, + **{ + expr_type: {"returns": exp.DataType.Type.BIGINT} + for expr_type in { + exp.BitmapBitPosition, + exp.BitmapBucketNumber, + exp.BitmapCount, + exp.Factorial, + exp.GroupingId, + exp.MD5NumberLower64, + exp.MD5NumberUpper64, + exp.Rand, + exp.Zipf, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BINARY} + for expr_type in { + exp.Base64DecodeBinary, + exp.BitmapConstructAgg, + exp.BitmapOrAgg, + exp.Compress, + 
exp.DecompressBinary, + exp.HexString, + exp.MD5Digest, + exp.SHA1Digest, + exp.SHA2Digest, + exp.ToBinary, + exp.TryBase64DecodeBinary, + exp.TryHexDecodeBinary, + exp.Unhex, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.BOOLEAN} + for expr_type in { + exp.Booland, + exp.Boolnot, + exp.Boolor, + exp.BoolxorAgg, + exp.EqualNull, + exp.IsNullValue, + exp.MapContainsKey, + exp.Search, + exp.SearchIp, + exp.ToBoolean, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DATE} + for expr_type in { + exp.NextDay, + exp.PreviousDay, + } + }, + **{ + expr_type: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("NUMBER", dialect="snowflake") + ) + } + for expr_type in ( + exp.BitwiseAndAgg, + exp.BitwiseOrAgg, + exp.BitwiseXorAgg, + exp.RegexpCount, + exp.RegexpInstr, + exp.ToNumber, + ) + }, + **{ + expr_type: {"returns": exp.DataType.Type.DOUBLE} + for expr_type in { + exp.ApproxPercentileEstimate, + exp.ApproximateSimilarity, + exp.Asinh, + exp.Atanh, + exp.Cbrt, + exp.Cosh, + exp.CosineDistance, + exp.DotProduct, + exp.EuclideanDistance, + exp.ManhattanDistance, + exp.MonthsBetween, + exp.Normal, + exp.Sinh, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.DECFLOAT} + for expr_type in { + exp.ToDecfloat, + exp.TryToDecfloat, + } + }, + **{ + expr_type: {"annotator": _annotate_math_with_float_decfloat} + for expr_type in { + exp.Acos, + exp.Asin, + exp.Atan, + exp.Atan2, + exp.Cos, + exp.Cot, + exp.Degrees, + exp.Exp, + exp.Ln, + exp.Log, + exp.Pow, + exp.Radians, + exp.RegrAvgx, + exp.RegrAvgy, + exp.RegrCount, + exp.RegrIntercept, + exp.RegrR2, + exp.RegrSlope, + exp.RegrSxx, + exp.RegrSxy, + exp.RegrSyy, + exp.RegrValx, + exp.RegrValy, + exp.Sin, + exp.Sqrt, + exp.Tan, + exp.Tanh, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.INT} + for expr_type in { + exp.Ascii, + exp.BitLength, + exp.ByteLength, + exp.Getbit, + exp.Grouping, + exp.Hour, + exp.JarowinklerSimilarity, + exp.Length, + exp.Levenshtein, + exp.MapSize, + exp.Minute, + exp.RtrimmedLength, + exp.Second, + exp.StrPosition, + exp.Unicode, + exp.WidthBucket, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.OBJECT} + for expr_type in { + exp.ApproxPercentileAccumulate, + exp.ApproxPercentileCombine, + exp.ApproxTopKAccumulate, + exp.ApproxTopKCombine, + exp.ObjectAgg, + exp.ParseIp, + exp.ParseUrl, + exp.XMLGet, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.MAP} + for expr_type in { + exp.MapCat, + exp.MapDelete, + exp.MapInsert, + exp.MapPick, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.FILE} + for expr_type in { + exp.ToFile, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.TIME} + for expr_type in { + exp.TimeFromParts, + exp.TsOrDsToTime, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.VARCHAR} + for expr_type in { + exp.AIAgg, + exp.AIClassify, + exp.AISummarizeAgg, + exp.Base64DecodeString, + exp.Base64Encode, + exp.CheckJson, + exp.CheckXml, + exp.Chr, + exp.Collate, + exp.Collation, + exp.CurrentAccount, + exp.CurrentAccountName, + exp.CurrentAvailableRoles, + exp.CurrentClient, + exp.CurrentDatabase, + exp.CurrentIpAddress, + exp.CurrentSchemas, + exp.CurrentSecondaryRoles, + exp.CurrentSession, + exp.CurrentStatement, + exp.CurrentVersion, + exp.CurrentTransaction, + exp.CurrentWarehouse, + exp.CurrentOrganizationUser, + exp.CurrentRegion, + exp.CurrentRole, + exp.CurrentRoleType, + exp.CurrentOrganizationName, + exp.DecompressString, + exp.HexDecodeString, + exp.HexEncode, + exp.Initcap, + exp.MD5, + 
exp.Monthname, + exp.Randstr, + exp.RegexpExtract, + exp.RegexpReplace, + exp.Repeat, + exp.Replace, + exp.SHA, + exp.SHA2, + exp.Soundex, + exp.SoundexP123, + exp.Space, + exp.SplitPart, + exp.Translate, + exp.TryBase64DecodeString, + exp.TryHexDecodeString, + exp.Uuid, + } + }, + **{ + expr_type: {"returns": exp.DataType.Type.VARIANT} + for expr_type in { + exp.Minhash, + exp.MinhashCombine, + } + }, + **{ + expr_type: {"annotator": _annotate_variance} + for expr_type in ( + exp.Variance, + exp.VariancePop, + ) + }, + exp.ArgMax: {"annotator": _annotate_arg_max_min}, + exp.ArgMin: {"annotator": _annotate_arg_max_min}, + exp.ConcatWs: { + "annotator": lambda self, e: self._annotate_by_args(e, "expressions") + }, + exp.ConvertTimezone: { + "annotator": lambda self, e: self._set_type( + e, + exp.DataType.Type.TIMESTAMPNTZ + if e.args.get("source_tz") + else exp.DataType.Type.TIMESTAMPTZ, + ) + }, + exp.DateAdd: {"annotator": _annotate_date_or_time_add}, + exp.DecodeCase: {"annotator": _annotate_decode_case}, + exp.HashAgg: { + "annotator": lambda self, e: self._set_type( + e, exp.DataType.build("NUMBER(19, 0)", dialect="snowflake") + ) + }, + exp.Median: {"annotator": _annotate_median}, + exp.Reverse: {"annotator": _annotate_reverse}, + exp.StrToTime: {"annotator": _annotate_str_to_time}, + exp.TimeAdd: {"annotator": _annotate_date_or_time_add}, + exp.TimestampFromParts: {"annotator": _annotate_timestamp_from_parts}, + exp.WithinGroup: {"annotator": _annotate_within_group}, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/spark2.py b/third_party/bigframes_vendored/sqlglot/typing/spark2.py new file mode 100644 index 0000000000..61f2581692 --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/spark2.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import typing as t + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.helper import ensure_list +from bigframes_vendored.sqlglot.typing.hive import ( + EXPRESSION_METADATA as HIVE_EXPRESSION_METADATA, +) + +if t.TYPE_CHECKING: + from bigframes_vendored.sqlglot._typing import E + from bigframes_vendored.sqlglot.optimizer.annotate_types import TypeAnnotator + from bigframes_vendored.sqlglot.typing import ExpressionMetadataType + + +def _annotate_by_similar_args( + self: TypeAnnotator, + expression: E, + *args: str, + target_type: exp.DataType | exp.DataType.Type, +) -> E: + """ + Infers the type of the expression according to the following rules: + - If all args are of the same type OR any arg is of target_type, the expr is inferred as such + - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN + """ + expressions: t.List[exp.Expression] = [] + for arg in args: + arg_expr = expression.args.get(arg) + expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + + last_datatype = None + + has_unknown = False + for expr in expressions: + if expr.is_type(exp.DataType.Type.UNKNOWN): + has_unknown = True + elif expr.is_type(target_type): + has_unknown = False + last_datatype = target_type + break + else: + last_datatype = expr.type + + self._set_type( + expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype + ) + return expression + + +EXPRESSION_METADATA: ExpressionMetadataType = { + **HIVE_EXPRESSION_METADATA, + exp.Substring: {"annotator": lambda self, e: self._annotate_by_args(e, "this")}, + exp.Concat: { + "annotator": lambda self, e: _annotate_by_similar_args( + self, e, "expressions", target_type=exp.DataType.Type.TEXT + ) + }, 
+ exp.Pad: { + "annotator": lambda self, e: _annotate_by_similar_args( + self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT + ) + }, +} diff --git a/third_party/bigframes_vendored/sqlglot/typing/tsql.py b/third_party/bigframes_vendored/sqlglot/typing/tsql.py new file mode 100644 index 0000000000..48aaed603e --- /dev/null +++ b/third_party/bigframes_vendored/sqlglot/typing/tsql.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from bigframes_vendored.sqlglot import exp +from bigframes_vendored.sqlglot.typing import EXPRESSION_METADATA + +EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + exp.Radians: {"annotator": lambda self, e: self._annotate_by_args(e, "this")}, +}
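Across all of these modules the metadata contract is uniform: an entry either pins a static result type under "returns" or registers an "annotator" callable that receives the TypeAnnotator and the expression. A minimal dispatch sketch under that contract; annotate_one is illustrative and not part of the vendored API, while _set_type mirrors the helper the tables above rely on:

import typing as t

from bigframes_vendored.sqlglot import exp
from bigframes_vendored.sqlglot.typing.tsql import EXPRESSION_METADATA


def annotate_one(annotator: t.Any, expression: exp.Expression) -> exp.Expression:
    # Illustrative driver: apply a single metadata entry to a single node.
    metadata = EXPRESSION_METADATA.get(type(expression))
    if metadata is None:
        return expression  # no rule registered for this node type
    if "returns" in metadata:
        # Static rule: the node always yields this type.
        annotator._set_type(expression, metadata["returns"])
        return expression
    # Dynamic rule: delegate to the registered annotator callable.
    return metadata["annotator"](annotator, expression)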