1616
1717from __future__ import annotations
1818
19- from typing import Sequence
19+ from typing import cast , Optional , Sequence , Union
2020
2121import google .cloud .bigquery
2222
2323from bigframes .core .compile .sqlglot import sql
24+ import bigframes .dataframe
2425import bigframes .dtypes
2526import bigframes .operations
2627import bigframes .series
2728
2829
30+ def _format_names (sql_template : str , dataframe : bigframes .dataframe .DataFrame ):
31+ """Turn sql_template from a template that uses names to one that uses
32+ numbers.
33+ """
34+ names_to_numbers = {name : f"{{{ i } }}" for i , name in enumerate (dataframe .columns )}
35+ numbers = [f"{{{ i } }}" for i in range (len (dataframe .columns ))]
36+ return sql_template .format (* numbers , ** names_to_numbers )
37+
38+
def sql_scalar(
    sql_template: str,
    columns: Union[bigframes.dataframe.DataFrame, Sequence[bigframes.series.Series]],
    *,
    output_dtype: Optional[bigframes.dtypes.Dtype] = None,
) -> bigframes.series.Series:
    """Create a Series from a SQL template.

    The template uses Python ``str.format`` placeholders. When ``columns``
    is a sequence of Series, refer to inputs positionally (``{0}``, ``{1}``,
    ...); when it is a DataFrame, refer to inputs by column name
    (``{column_name}``).

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq

    Either pass in a sequence of series, in which case use integers in the
    format strings.

        >>> s = bpd.Series(["1.5", "2.5", "3.5"])
        >>> s = s.astype(pd.ArrowDtype(pa.decimal128(38, 9)))
        >>> bbq.sql_scalar("ROUND({0}, 0, 'ROUND_HALF_EVEN')", [s])
        0    2.000000000
        1    2.000000000
        2    4.000000000
        dtype: decimal128(38, 9)[pyarrow]

    Or pass in a DataFrame, in which case use the column names in the format
    strings.

        >>> df = bpd.DataFrame({"a": ["1.5", "2.5", "3.5"]})
        >>> df = df.astype({"a": pd.ArrowDtype(pa.decimal128(38, 9))})
        >>> bbq.sql_scalar("ROUND({a}, 0, 'ROUND_HALF_EVEN')", df)
        0    2.000000000
        1    2.000000000
        2    4.000000000
        dtype: decimal128(38, 9)[pyarrow]

    Args:
        sql_template (str):
            A SQL format string with Python-style {0} placeholders for each of
            the Series objects in ``columns``.
        columns (
            Sequence[bigframes.pandas.Series] | bigframes.pandas.DataFrame
        ):
            Series objects representing the column inputs to the
            ``sql_template``. Must contain at least one Series.
        output_dtype (a BigQuery DataFrames compatible dtype, optional):
            If provided, BigQuery DataFrames uses this to determine the output
            of the returned Series. This avoids a dry run query.

    Returns:
        bigframes.pandas.Series:
            A Series with the result of applying ``sql_template`` row-wise.

    Raises:
        ValueError: If ``columns`` is empty.
    """
    # Normalize the DataFrame form to the Sequence[Series] form: rewrite
    # name placeholders to numbered ones, then extract the columns in order.
    if isinstance(columns, bigframes.dataframe.DataFrame):
        sql_template = _format_names(sql_template, columns)
        columns = [
            cast(bigframes.series.Series, columns[name]) for name in columns.columns
        ]

    if not len(columns):
        raise ValueError("Must provide at least one column in columns")

    first_series = columns[0]

    # To integrate this into our expression trees, we need to get the output
    # type, so we do some manual compilation and a dry run query to get that.
    # Another benefit of this is that if there is a syntax error in the SQL
    # template, then this will fail with an error earlier in the process,
    # aiding users in debugging.
    if output_dtype is None:
        null_literals = [
            sql.to_sql(sql.literal(None, column.dtype)) for column in columns
        ]
        selected = sql_template.format(*null_literals)
        dry_run_sql = f"SELECT {selected}"

        # Use the executor directly, because we want the original column IDs,
        # not the user-friendly column names that block.to_sql_query() would
        # produce.
        bqclient = first_series._session.bqclient
        dry_run_job = bqclient.query(
            dry_run_sql,
            job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True),
        )
        _, output_dtype = bigframes.dtypes.convert_schema_field(dry_run_job.schema[0])

    scalar_op = bigframes.operations.SqlScalarOp(
        _output_type=output_dtype, sql_template=sql_template
    )
    return first_series._apply_nary_op(scalar_op, columns[1:])
0 commit comments