From f928c91e2403d3ed8df386aff8a357cef553fa94 Mon Sep 17 00:00:00 2001 From: petrmarinec <54589756+petrmarinec@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:54:51 +0200 Subject: [PATCH] Fix SQL injection in BigQuery ML tools --- src/google/adk/tools/bigquery/query_tool.py | 343 +++++++---- .../bigquery/test_bigquery_query_tool.py | 542 ++++++++++++------ 2 files changed, 595 insertions(+), 290 deletions(-) diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index d5b1264fa5..ea756aae62 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -16,6 +16,7 @@ import functools import json +import re import types from typing import Callable from typing import Optional @@ -30,6 +31,74 @@ from .config import WriteMode BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info" +_FIELD_PATH_RE = re.compile( + r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*" +) + + +def _validate_table_id( + table_id: str, *, project_id: str, argument_name: str +) -> str: + """Validate that an input is a BigQuery table identifier, not a query.""" + if not isinstance(table_id, str): + raise ValueError(f"{argument_name} must be a BigQuery table ID.") + + normalized_table_id = table_id.strip() + if not normalized_table_id: + raise ValueError(f"{argument_name} must be a non-empty BigQuery table ID.") + + trimmed_upper_table_id = normalized_table_id.upper() + if trimmed_upper_table_id.startswith( + "SELECT" + ) or trimmed_upper_table_id.startswith("WITH"): + raise ValueError( + f"{argument_name} must be a BigQuery table ID. SQL query statements" + " are not supported by this tool." + ) + + try: + bigquery.TableReference.from_string( + normalized_table_id, default_project=project_id + ) + except ValueError as ex: + raise ValueError( + f"{argument_name} must be a valid BigQuery table ID." + ) from ex + + return normalized_table_id + + +def _quote_field_path(field_path: str, *, argument_name: str) -> str: + """Quote a BigQuery field path for safe use in ORDER BY clauses.""" + if not isinstance(field_path, str) or not _FIELD_PATH_RE.fullmatch(field_path): + raise ValueError( + f"{argument_name} must be a valid BigQuery column name or field path." + ) + + return ".".join(f"`{part}`" for part in field_path.split(".")) + + +def _make_query_job_config( + *, + dry_run: bool = False, + create_session: bool = False, + connection_properties: Optional[list] = None, + labels: Optional[dict[str, str]] = None, + query_parameters: Optional[list] = None, +) -> bigquery.QueryJobConfig: + """Build a QueryJobConfig without passing unset optional fields.""" + kwargs = {} + if dry_run: + kwargs["dry_run"] = True + if create_session: + kwargs["create_session"] = True + if connection_properties: + kwargs["connection_properties"] = connection_properties + if labels is not None: + kwargs["labels"] = labels + if query_parameters is not None: + kwargs["query_parameters"] = query_parameters + return bigquery.QueryJobConfig(**kwargs) def _execute_sql( @@ -40,6 +109,7 @@ def _execute_sql( tool_context: ToolContext, dry_run: bool = False, caller_id: Optional[str] = None, + query_parameters: Optional[list] = None, ) -> dict: try: # Validate compute project if applicable @@ -81,8 +151,10 @@ def _execute_sql( dry_run_query_job = bq_client.query( query, project=project_id, - job_config=bigquery.QueryJobConfig( - dry_run=True, labels=bq_job_labels + job_config=_make_query_job_config( + dry_run=True, + labels=bq_job_labels, + query_parameters=query_parameters, ), ) if dry_run_query_job.statement_type != "SELECT": @@ -102,8 +174,11 @@ def _execute_sql( session_creator_job = bq_client.query( "SELECT 1", project=project_id, - job_config=bigquery.QueryJobConfig( - dry_run=True, create_session=True, labels=bq_job_labels + job_config=_make_query_job_config( + dry_run=True, + create_session=True, + labels=bq_job_labels, + query_parameters=query_parameters, ), ) bq_session_id = session_creator_job.session_info.session_id @@ -124,10 +199,11 @@ def _execute_sql( dry_run_query_job = bq_client.query( query, project=project_id, - job_config=bigquery.QueryJobConfig( + job_config=_make_query_job_config( dry_run=True, connection_properties=bq_connection_properties, labels=bq_job_labels, + query_parameters=query_parameters, ), ) if ( @@ -148,18 +224,20 @@ def _execute_sql( dry_run_job = bq_client.query( query, project=project_id, - job_config=bigquery.QueryJobConfig( + job_config=_make_query_job_config( dry_run=True, connection_properties=bq_connection_properties, labels=bq_job_labels, + query_parameters=query_parameters, ), ) return {"status": "SUCCESS", "dry_run_info": dry_run_job.to_api_repr()} # Finally execute the query, fetch the result, and return it - job_config = bigquery.QueryJobConfig( + job_config = _make_query_job_config( connection_properties=bq_connection_properties, labels=bq_job_labels, + query_parameters=query_parameters, ) if settings.maximum_bytes_billed: job_config.maximum_bytes_billed = settings.maximum_bytes_billed @@ -781,9 +859,8 @@ def forecast( Args: project_id (str): The GCP project id in which the query should be executed. - history_data (str): The table id of the BigQuery table containing the - history time series data or a query statement that select the history - data. + history_data (str): The table ID of the BigQuery table containing the + history time series data. timestamp_col (str): The name of the column containing the timestamp for each data point. data_col (str): The name of the column containing the numerical values to @@ -827,16 +904,11 @@ def forecast( ] } - Forecast multiple time series using a SQL query as input: + Forecast multiple time series from a BigQuery table: - >>> history_query = ( - ... "SELECT unique_id, timestamp, value " - ... "FROM `my-project.my-dataset.my-timeseries-table` " - ... "WHERE timestamp > '1980-01-01'" - ... ) >>> forecast( ... project_id="my-gcp-project", - ... history_data=history_query, + ... history_data="my-project.my_dataset.my_timeseries_table", ... timestamp_col="timestamp", ... data_col="value", ... id_cols=["unique_id"], @@ -890,13 +962,23 @@ def forecast( """ model = "TimesFM 2.0" confidence_level = 0.95 - trimmed_upper_history_data = history_data.strip().upper() - if trimmed_upper_history_data.startswith( - "SELECT" - ) or trimmed_upper_history_data.startswith("WITH"): - history_data_source = f"({history_data})" - else: - history_data_source = f"TABLE `{history_data}`" + try: + history_data = _validate_table_id( + history_data, project_id=project_id, argument_name="history_data" + ) + except ValueError as ex: + return {"status": "ERROR", "error_details": str(ex)} + + history_data_source = f"TABLE `{history_data}`" + query_parameters = [ + bigquery.ScalarQueryParameter("data_col", "STRING", data_col), + bigquery.ScalarQueryParameter("timestamp_col", "STRING", timestamp_col), + bigquery.ScalarQueryParameter("model", "STRING", model), + bigquery.ScalarQueryParameter("horizon", "INT64", horizon), + bigquery.ScalarQueryParameter( + "confidence_level", "FLOAT64", confidence_level + ), + ] if id_cols: if not all(isinstance(item, str) for item in id_cols): @@ -904,28 +986,30 @@ def forecast( "status": "ERROR", "error_details": "All elements in id_cols must be strings.", } - id_cols_str = "[" + ", ".join([f"'{col}'" for col in id_cols]) + "]" + query_parameters.append( + bigquery.ArrayQueryParameter("id_cols", "STRING", id_cols) + ) query = f""" SELECT * FROM AI.FORECAST( {history_data_source}, - data_col => '{data_col}', - timestamp_col => '{timestamp_col}', - model => '{model}', - id_cols => {id_cols_str}, - horizon => {horizon}, - confidence_level => {confidence_level} + data_col => @data_col, + timestamp_col => @timestamp_col, + model => @model, + id_cols => @id_cols, + horizon => @horizon, + confidence_level => @confidence_level ) """ else: query = f""" SELECT * FROM AI.FORECAST( {history_data_source}, - data_col => '{data_col}', - timestamp_col => '{timestamp_col}', - model => '{model}', - horizon => {horizon}, - confidence_level => {confidence_level} + data_col => @data_col, + timestamp_col => @timestamp_col, + model => @model, + horizon => @horizon, + confidence_level => @confidence_level ) """ return _execute_sql( @@ -935,6 +1019,7 @@ def forecast( settings=settings, tool_context=tool_context, caller_id="forecast", + query_parameters=query_parameters, ) @@ -955,8 +1040,8 @@ def analyze_contribution( Args: project_id (str): The GCP project id in which the query should be executed. - input_data (str): The data that contain the test and control data to - analyze. Can be a fully qualified BigQuery table ID or a SQL query. + input_data (str): The BigQuery table ID that contains the test and + control data to analyze. dimension_id_cols (list[str]): The column names of the dimension columns. contribution_metric (str): The name of the column that contains the metric to analyze. Provides the expression to use to calculate the metric you @@ -1016,38 +1101,14 @@ def analyze_contribution( ] } - Analyze the contribution of different dimensions to the total sales using - a SQL query as input: - - >>> analyze_contribution( - ... project_id="my-gcp-project", - ... input_data="SELECT store_id, product_category, total_sales, " - ... "is_test FROM `my-project.my-dataset.my-sales-table` " - ... "WHERE transaction_date > '2025-01-01'" - ... dimension_id_cols=["store_id", "product_category"], - ... contribution_metric="SUM(total_sales)", - ... is_test_col="is_test" - ... ) - The return is: - { - "status": "SUCCESS", - "rows": [ - { - "store_id": "S2", - "product_category": "Groceries", - "contributors": ["S2", "Groceries"], - "metric_test": 250, - "metric_control": 200, - "difference": 50, - "relative_difference": 0.25, - "unexpected_difference": 10, - "relative_unexpected_difference": 0.041, - "apriori_support": 0.22 - }, - ... - ] - } """ + try: + input_data = _validate_table_id( + input_data, project_id=project_id, argument_name="input_data" + ) + except ValueError as ex: + return {"status": "ERROR", "error_details": str(ex)} + if not all(isinstance(item, str) for item in dimension_id_cols): return { "status": "ERROR", @@ -1059,15 +1120,14 @@ def analyze_contribution( f"contribution_analysis_model_{str(uuid.uuid4()).replace('-', '_')}" ) - id_cols_str = "[" + ", ".join([f"'{col}'" for col in dimension_id_cols]) + "]" options = [ "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'", - f"CONTRIBUTION_METRIC = '{contribution_metric}'", - f"IS_TEST_COL = '{is_test_col}'", - f"DIMENSION_ID_COLS = {id_cols_str}", + "CONTRIBUTION_METRIC = @contribution_metric", + "IS_TEST_COL = @is_test_col", + "DIMENSION_ID_COLS = @dimension_id_cols", ] - options.append(f"TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = {top_k_insights}") + options.append("TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = @top_k_insights") upper_pruning = pruning_method.upper() if upper_pruning not in ["NO_PRUNING", "PRUNE_REDUNDANT_INSIGHTS"]: @@ -1075,17 +1135,21 @@ def analyze_contribution( "status": "ERROR", "error_details": f"Invalid pruning_method: {pruning_method}", } - options.append(f"PRUNING_METHOD = '{upper_pruning}'") + options.append("PRUNING_METHOD = @pruning_method") options_str = ", ".join(options) - - trimmed_upper_input_data = input_data.strip().upper() - if trimmed_upper_input_data.startswith( - "SELECT" - ) or trimmed_upper_input_data.startswith("WITH"): - input_data_source = f"({input_data})" - else: - input_data_source = f"SELECT * FROM `{input_data}`" + input_data_source = f"SELECT * FROM `{input_data}`" + query_parameters = [ + bigquery.ScalarQueryParameter( + "contribution_metric", "STRING", contribution_metric + ), + bigquery.ScalarQueryParameter("is_test_col", "STRING", is_test_col), + bigquery.ArrayQueryParameter( + "dimension_id_cols", "STRING", dimension_id_cols + ), + bigquery.ScalarQueryParameter("top_k_insights", "INT64", top_k_insights), + bigquery.ScalarQueryParameter("pruning_method", "STRING", upper_pruning), + ] create_model_query = f""" CREATE TEMP MODEL {model_name} @@ -1117,6 +1181,7 @@ def analyze_contribution( settings=execute_sql_settings, tool_context=tool_context, caller_id="analyze_contribution", + query_parameters=query_parameters, ) if result["status"] != "SUCCESS": return result @@ -1157,18 +1222,16 @@ def detect_anomalies( Args: project_id (str): The GCP project id in which the query should be executed. - history_data (str): The table id of the BigQuery table containing the - history time series data or a query statement that select the history - data. + history_data (str): The table ID of the BigQuery table containing the + history time series data. times_series_timestamp_col (str): The name of the column containing the timestamp for each data point. times_series_data_col (str): The name of the column containing the numerical values to be forecasted and anomaly detected. horizon (int, optional): The number of time steps to forecast into the future. Defaults to 1000. - target_data (str, optional): The table id of the BigQuery table containing - the target time series data or a query statement that select the target - data. + target_data (str, optional): The table ID of the BigQuery table + containing the target time series data. times_series_id_cols (list, optional): The column names of the id columns to indicate each time series when there are multiple time series in the table. All elements must be strings. Defaults to None. @@ -1210,16 +1273,11 @@ def detect_anomalies( ] } - Detect Anomalies on multiple time series using a SQL query as input: + Detect Anomalies on multiple time series using a BigQuery table: - >>> history_query = ( - ... "SELECT unique_id, timestamp, value " - ... "FROM `my-project.my-dataset.my-timeseries-table` " - ... "WHERE timestamp > '1980-01-01'" - ... ) >>> detect_anomalies( ... project_id="my-gcp-project", - ... history_data=history_query, + ... history_data="my-project.my_dataset.my_timeseries_table", ... times_series_timestamp_col="timestamp", ... times_series_data_col="value", ... times_series_id_cols=["unique_id"] @@ -1271,20 +1329,38 @@ def detect_anomalies( location US" } """ - trimmed_upper_history_data = history_data.strip().upper() - if trimmed_upper_history_data.startswith( - "SELECT" - ) or trimmed_upper_history_data.startswith("WITH"): - history_data_source = f"({history_data})" - else: - history_data_source = f"SELECT * FROM `{history_data}`" + try: + history_data = _validate_table_id( + history_data, project_id=project_id, argument_name="history_data" + ) + except ValueError as ex: + return {"status": "ERROR", "error_details": str(ex)} + + history_data_source = f"SELECT * FROM `{history_data}`" options = [ "MODEL_TYPE = 'ARIMA_PLUS'", - f"TIME_SERIES_TIMESTAMP_COL = '{times_series_timestamp_col}'", - f"TIME_SERIES_DATA_COL = '{times_series_data_col}'", - f"HORIZON = {horizon}", + "TIME_SERIES_TIMESTAMP_COL = @times_series_timestamp_col", + "TIME_SERIES_DATA_COL = @times_series_data_col", + "HORIZON = @horizon", + ] + create_model_parameters = [ + bigquery.ScalarQueryParameter( + "times_series_timestamp_col", "STRING", times_series_timestamp_col + ), + bigquery.ScalarQueryParameter( + "times_series_data_col", "STRING", times_series_data_col + ), + bigquery.ScalarQueryParameter("horizon", "INT64", horizon), ] + try: + quoted_timestamp_col = _quote_field_path( + times_series_timestamp_col, + argument_name="times_series_timestamp_col", + ) + except ValueError as ex: + return {"status": "ERROR", "error_details": str(ex)} + order_by_cols = [quoted_timestamp_col] if times_series_id_cols: if not all(isinstance(item, str) for item in times_series_id_cols): @@ -1294,10 +1370,28 @@ def detect_anomalies( "All elements in times_series_id_cols must be strings." ), } - times_series_id_cols_str = ( - "[" + ", ".join([f"'{col}'" for col in times_series_id_cols]) + "]" + try: + quoted_id_cols = [ + _quote_field_path( + col, argument_name="times_series_id_cols" + ) + for col in times_series_id_cols + ] + except ValueError: + return { + "status": "ERROR", + "error_details": ( + "All elements in times_series_id_cols must be valid BigQuery" + " column names or field paths." + ), + } + options.append("TIME_SERIES_ID_COL = @times_series_id_cols") + create_model_parameters.append( + bigquery.ArrayQueryParameter( + "times_series_id_cols", "STRING", times_series_id_cols + ) ) - options.append(f"TIME_SERIES_ID_COL = {times_series_id_cols_str}") + order_by_cols = quoted_id_cols + order_by_cols options_str = ", ".join(options) @@ -1308,26 +1402,27 @@ def detect_anomalies( OPTIONS ({options_str}) AS {history_data_source} """ - order_by_id_cols = ( - ", ".join(col for col in times_series_id_cols) + ", " - if times_series_id_cols - else "" - ) + order_by_clause = ", ".join(order_by_cols) + anomaly_detection_parameters = [ + bigquery.ScalarQueryParameter( + "anomaly_prob_threshold", "FLOAT64", anomaly_prob_threshold + ) + ] anomaly_detection_query = f""" - SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold)) ORDER BY {order_by_id_cols}{times_series_timestamp_col} + SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold)) ORDER BY {order_by_clause} """ if target_data: - trimmed_upper_target_data = target_data.strip().upper() - if trimmed_upper_target_data.startswith( - "SELECT" - ) or trimmed_upper_target_data.startswith("WITH"): - target_data_source = f"({target_data})" - else: - target_data_source = f"(SELECT * FROM `{target_data}`)" + try: + target_data = _validate_table_id( + target_data, project_id=project_id, argument_name="target_data" + ) + except ValueError as ex: + return {"status": "ERROR", "error_details": str(ex)} + target_data_source = f"(SELECT * FROM `{target_data}`)" anomaly_detection_query = f""" - SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold), {target_data_source}) ORDER BY {order_by_id_cols}{times_series_timestamp_col} + SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold), {target_data_source}) ORDER BY {order_by_clause} """ # Create a session and run the create model query. @@ -1350,6 +1445,7 @@ def detect_anomalies( settings=execute_sql_settings, tool_context=tool_context, caller_id="detect_anomalies", + query_parameters=create_model_parameters, ) if result["status"] != "SUCCESS": return result @@ -1361,6 +1457,7 @@ def detect_anomalies( settings=execute_sql_settings, tool_context=tool_context, caller_id="detect_anomalies", + query_parameters=anomaly_detection_parameters, ) except Exception as ex: # pylint: disable=broad-except return { diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 150cdb7569..26e6ba07fc 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -66,6 +66,10 @@ async def get_tool( return tools[0] +def _query_parameters_to_api_repr(query_parameters): + return [parameter.to_api_repr() for parameter in query_parameters] + + @pytest.mark.parametrize( ("tool_settings",), [ @@ -1231,35 +1235,73 @@ def test_forecast_with_table_id(mock_execute_sql): expected_query = """ SELECT * FROM AI.FORECAST( TABLE `test-dataset.test-table`, - data_col => 'data_col', - timestamp_col => 'ts_col', - model => 'TimesFM 2.0', - id_cols => ['id1', 'id2'], - horizon => 20, - confidence_level => 0.95 + data_col => @data_col, + timestamp_col => @timestamp_col, + model => @model, + id_cols => @id_cols, + horizon => @horizon, + confidence_level => @confidence_level ) """ - mock_execute_sql.assert_called_once_with( - project_id="test-project", - query=expected_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="forecast", - ) + mock_execute_sql.assert_called_once() + assert mock_execute_sql.call_args.kwargs["project_id"] == "test-project" + assert mock_execute_sql.call_args.kwargs["query"] == expected_query + assert mock_execute_sql.call_args.kwargs["credentials"] == mock_credentials + assert mock_execute_sql.call_args.kwargs["settings"] == mock_settings + assert mock_execute_sql.call_args.kwargs["tool_context"] == mock_tool_context + assert mock_execute_sql.call_args.kwargs["caller_id"] == "forecast" + assert _query_parameters_to_api_repr( + mock_execute_sql.call_args.kwargs["query_parameters"] + ) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "data_col"}, + "name": "data_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_col"}, + "name": "timestamp_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "TimesFM 2.0"}, + "name": "model", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "20"}, + "name": "horizon", + }, + { + "parameterType": {"type": "FLOAT64"}, + "parameterValue": {"value": 0.95}, + "name": "confidence_level", + }, + { + "parameterType": { + "type": "ARRAY", + "arrayType": {"type": "STRING"}, + }, + "parameterValue": { + "arrayValues": [{"value": "id1"}, {"value": "id2"}] + }, + "name": "id_cols", + }, + ] # AI.Forecast calls _execute_sql with a specific query statement. We need to # test that the query is properly constructed and call _execute_sql with the # correct parameters exactly once. @mock.patch.object(query_tool, "_execute_sql", autospec=True) -def test_forecast_with_query_statement(mock_execute_sql): +def test_forecast_rejects_query_statement_history_data(mock_execute_sql): mock_credentials = mock.MagicMock(spec=Credentials) mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) history_data_query = "SELECT * FROM `test-dataset.test-table`" - query_tool.forecast( + result = query_tool.forecast( project_id="test-project", history_data=history_data_query, timestamp_col="ts_col", @@ -1269,24 +1311,14 @@ def test_forecast_with_query_statement(mock_execute_sql): tool_context=mock_tool_context, ) - expected_query = f""" - SELECT * FROM AI.FORECAST( - ({history_data_query}), - data_col => 'data_col', - timestamp_col => 'ts_col', - model => 'TimesFM 2.0', - horizon => 10, - confidence_level => 0.95 - ) - """ - mock_execute_sql.assert_called_once_with( - project_id="test-project", - query=expected_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="forecast", - ) + assert result == { + "status": "ERROR", + "error_details": ( + "history_data must be a BigQuery table ID. SQL query statements are" + " not supported by this tool." + ), + } + mock_execute_sql.assert_not_called() def test_forecast_with_invalid_id_cols(): @@ -1334,7 +1366,7 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): expected_create_model_query = """ CREATE TEMP MODEL contribution_analysis_model_test_uuid - OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS') + OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = @contribution_metric, IS_TEST_COL = @is_test_col, DIMENSION_ID_COLS = @dimension_id_cols, TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = @top_k_insights, PRUNING_METHOD = @pruning_method) AS SELECT * FROM `test-dataset.test-table` """ @@ -1343,22 +1375,53 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): """ assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_create_model_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="analyze_contribution", - ) - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_get_insights_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="analyze_contribution", - ) + create_call = mock_execute_sql.call_args_list[0].kwargs + assert create_call["project_id"] == "test-project" + assert create_call["query"] == expected_create_model_query + assert create_call["credentials"] == mock_credentials + assert create_call["settings"] == mock_settings + assert create_call["tool_context"] == mock_tool_context + assert create_call["caller_id"] == "analyze_contribution" + assert _query_parameters_to_api_repr(create_call["query_parameters"]) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "SUM(metric)"}, + "name": "contribution_metric", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "is_test"}, + "name": "is_test_col", + }, + { + "parameterType": { + "type": "ARRAY", + "arrayType": {"type": "STRING"}, + }, + "parameterValue": { + "arrayValues": [{"value": "dim1"}, {"value": "dim2"}] + }, + "name": "dimension_id_cols", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "30"}, + "name": "top_k_insights", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "PRUNE_REDUNDANT_INSIGHTS"}, + "name": "pruning_method", + }, + ] + insights_call = mock_execute_sql.call_args_list[1].kwargs + assert insights_call["project_id"] == "test-project" + assert insights_call["query"] == expected_get_insights_query + assert insights_call["credentials"] == mock_credentials + assert insights_call["settings"] == mock_settings + assert insights_call["tool_context"] == mock_tool_context + assert insights_call["caller_id"] == "analyze_contribution" + assert "query_parameters" not in insights_call # analyze_contribution calls _execute_sql twice. We need to test that the @@ -1366,15 +1429,17 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): # parameters exactly twice. @mock.patch.object(query_tool, "_execute_sql", autospec=True) @mock.patch.object(uuid, "uuid4", autospec=True) -def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql): - """Test analyze_contribution tool invocation with a query statement.""" +def test_analyze_contribution_rejects_query_statement( + mock_uuid, mock_execute_sql +): + """Test analyze_contribution rejects query statements as input_data.""" mock_credentials = mock.MagicMock(spec=Credentials) mock_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} input_data_query = "SELECT * FROM `test-dataset.test-table`" - query_tool.analyze_contribution( + result = query_tool.analyze_contribution( project_id="test-project", input_data=input_data_query, dimension_id_cols=["dim1", "dim2"], @@ -1385,33 +1450,14 @@ def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql): tool_context=mock_tool_context, ) - expected_create_model_query = f""" - CREATE TEMP MODEL contribution_analysis_model_test_uuid - OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS') - AS ({input_data_query}) - """ - - expected_get_insights_query = """ - SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model_test_uuid) - """ - - assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_create_model_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="analyze_contribution", - ) - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_get_insights_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="analyze_contribution", - ) + assert result == { + "status": "ERROR", + "error_details": ( + "input_data must be a BigQuery table ID. SQL query statements are" + " not supported by this tool." + ), + } + mock_execute_sql.assert_not_called() def test_analyze_contribution_with_invalid_dimension_id_cols(): @@ -1450,10 +1496,9 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" query_tool.detect_anomalies( project_id="test-project", - history_data=history_data_query, + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock_credentials, @@ -1463,31 +1508,53 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql): expected_create_model_query = """ CREATE TEMP MODEL detect_anomalies_model_test_uuid - OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 1000) - AS (SELECT * FROM `test-dataset.test-table`) + OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = @times_series_timestamp_col, TIME_SERIES_DATA_COL = @times_series_data_col, HORIZON = @horizon) + AS SELECT * FROM `test-dataset.test-table` """ expected_anomaly_detection_query = """ - SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.95 AS anomaly_prob_threshold)) ORDER BY ts_timestamp + SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold)) ORDER BY `ts_timestamp` """ assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_create_model_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_anomaly_detection_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) + create_call = mock_execute_sql.call_args_list[0].kwargs + assert create_call["project_id"] == "test-project" + assert create_call["query"] == expected_create_model_query + assert create_call["credentials"] == mock_credentials + assert create_call["settings"] == mock_settings + assert create_call["tool_context"] == mock_tool_context + assert create_call["caller_id"] == "detect_anomalies" + assert _query_parameters_to_api_repr(create_call["query_parameters"]) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_timestamp"}, + "name": "times_series_timestamp_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_data"}, + "name": "times_series_data_col", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "1000"}, + "name": "horizon", + }, + ] + detect_call = mock_execute_sql.call_args_list[1].kwargs + assert detect_call["project_id"] == "test-project" + assert detect_call["query"] == expected_anomaly_detection_query + assert detect_call["credentials"] == mock_credentials + assert detect_call["settings"] == mock_settings + assert detect_call["tool_context"] == mock_tool_context + assert detect_call["caller_id"] == "detect_anomalies" + assert _query_parameters_to_api_repr(detect_call["query_parameters"]) == [ + { + "parameterType": {"type": "FLOAT64"}, + "parameterValue": {"value": 0.95}, + "name": "anomaly_prob_threshold", + } + ] # detect_anomalies calls _execute_sql twice. We need to test that @@ -1502,10 +1569,9 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" query_tool.detect_anomalies( project_id="test-project", - history_data=history_data_query, + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", times_series_id_cols=["dim1", "dim2"], @@ -1518,31 +1584,63 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql): expected_create_model_query = """ CREATE TEMP MODEL detect_anomalies_model_test_uuid - OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 20, TIME_SERIES_ID_COL = ['dim1', 'dim2']) - AS (SELECT * FROM `test-dataset.test-table`) + OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = @times_series_timestamp_col, TIME_SERIES_DATA_COL = @times_series_data_col, HORIZON = @horizon, TIME_SERIES_ID_COL = @times_series_id_cols) + AS SELECT * FROM `test-dataset.test-table` """ expected_anomaly_detection_query = """ - SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold)) ORDER BY dim1, dim2, ts_timestamp + SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold)) ORDER BY `dim1`, `dim2`, `ts_timestamp` """ assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_create_model_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_anomaly_detection_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) + create_call = mock_execute_sql.call_args_list[0].kwargs + assert create_call["project_id"] == "test-project" + assert create_call["query"] == expected_create_model_query + assert create_call["credentials"] == mock_credentials + assert create_call["settings"] == mock_settings + assert create_call["tool_context"] == mock_tool_context + assert create_call["caller_id"] == "detect_anomalies" + assert _query_parameters_to_api_repr(create_call["query_parameters"]) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_timestamp"}, + "name": "times_series_timestamp_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_data"}, + "name": "times_series_data_col", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "20"}, + "name": "horizon", + }, + { + "parameterType": { + "type": "ARRAY", + "arrayType": {"type": "STRING"}, + }, + "parameterValue": { + "arrayValues": [{"value": "dim1"}, {"value": "dim2"}] + }, + "name": "times_series_id_cols", + }, + ] + detect_call = mock_execute_sql.call_args_list[1].kwargs + assert detect_call["project_id"] == "test-project" + assert detect_call["query"] == expected_anomaly_detection_query + assert detect_call["credentials"] == mock_credentials + assert detect_call["settings"] == mock_settings + assert detect_call["tool_context"] == mock_tool_context + assert detect_call["caller_id"] == "detect_anomalies" + assert _query_parameters_to_api_repr(detect_call["query_parameters"]) == [ + { + "parameterType": {"type": "FLOAT64"}, + "parameterValue": {"value": 0.8}, + "name": "anomaly_prob_threshold", + } + ] # detect_anomalies calls _execute_sql twice. We need to test that @@ -1557,16 +1655,14 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.history-table`" - target_data_query = "SELECT * FROM `test-dataset.target-table`" query_tool.detect_anomalies( project_id="test-project", - history_data=history_data_query, + history_data="test-dataset.history-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", times_series_id_cols=["dim1", "dim2"], horizon=20, - target_data=target_data_query, + target_data="test-dataset.target-table", anomaly_prob_threshold=0.8, credentials=mock_credentials, settings=mock_settings, @@ -1575,31 +1671,53 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql): expected_create_model_query = """ CREATE TEMP MODEL detect_anomalies_model_test_uuid - OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 20, TIME_SERIES_ID_COL = ['dim1', 'dim2']) - AS (SELECT * FROM `test-dataset.history-table`) + OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = @times_series_timestamp_col, TIME_SERIES_DATA_COL = @times_series_data_col, HORIZON = @horizon, TIME_SERIES_ID_COL = @times_series_id_cols) + AS SELECT * FROM `test-dataset.history-table` """ expected_anomaly_detection_query = """ - SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY dim1, dim2, ts_timestamp + SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY `dim1`, `dim2`, `ts_timestamp` """ assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_create_model_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) - mock_execute_sql.assert_any_call( - project_id="test-project", - query=expected_anomaly_detection_query, - credentials=mock_credentials, - settings=mock_settings, - tool_context=mock_tool_context, - caller_id="detect_anomalies", - ) + create_call = mock_execute_sql.call_args_list[0].kwargs + assert create_call["query"] == expected_create_model_query + assert _query_parameters_to_api_repr(create_call["query_parameters"]) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_timestamp"}, + "name": "times_series_timestamp_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_data"}, + "name": "times_series_data_col", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "20"}, + "name": "horizon", + }, + { + "parameterType": { + "type": "ARRAY", + "arrayType": {"type": "STRING"}, + }, + "parameterValue": { + "arrayValues": [{"value": "dim1"}, {"value": "dim2"}] + }, + "name": "times_series_id_cols", + }, + ] + detect_call = mock_execute_sql.call_args_list[1].kwargs + assert detect_call["query"] == expected_anomaly_detection_query + assert _query_parameters_to_api_repr(detect_call["query_parameters"]) == [ + { + "parameterType": {"type": "FLOAT64"}, + "parameterValue": {"value": 0.8}, + "name": "anomaly_prob_threshold", + } + ] # detect_anomalies calls execute_sql twice. We need to test that @@ -1614,10 +1732,9 @@ def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" query_tool.detect_anomalies( project_id="test-project", - history_data=history_data_query, + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", target_data="test-dataset.target-table", @@ -1628,32 +1745,97 @@ def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql): expected_create_model_query = """ CREATE TEMP MODEL detect_anomalies_model_test_uuid - OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 1000) - AS (SELECT * FROM `test-dataset.test-table`) + OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = @times_series_timestamp_col, TIME_SERIES_DATA_COL = @times_series_data_col, HORIZON = @horizon) + AS SELECT * FROM `test-dataset.test-table` """ expected_anomaly_detection_query = """ - SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.95 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY ts_timestamp + SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(@anomaly_prob_threshold AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY `ts_timestamp` """ assert mock_execute_sql.call_count == 2 - mock_execute_sql.assert_any_call( + create_call = mock_execute_sql.call_args_list[0].kwargs + assert create_call["query"] == expected_create_model_query + assert _query_parameters_to_api_repr(create_call["query_parameters"]) == [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_timestamp"}, + "name": "times_series_timestamp_col", + }, + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "ts_data"}, + "name": "times_series_data_col", + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "1000"}, + "name": "horizon", + }, + ] + detect_call = mock_execute_sql.call_args_list[1].kwargs + assert detect_call["query"] == expected_anomaly_detection_query + assert _query_parameters_to_api_repr(detect_call["query_parameters"]) == [ + { + "parameterType": {"type": "FLOAT64"}, + "parameterValue": {"value": 0.95}, + "name": "anomaly_prob_threshold", + } + ] + + +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +def test_detect_anomalies_rejects_query_statement_history_data(mock_execute_sql): + mock_credentials = mock.MagicMock(spec=Credentials) + mock_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + mock_tool_context = mock.create_autospec(ToolContext, instance=True) + + result = query_tool.detect_anomalies( project_id="test-project", - query=expected_create_model_query, + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", credentials=mock_credentials, settings=mock_settings, tool_context=mock_tool_context, - caller_id="detect_anomalies", ) - mock_execute_sql.assert_any_call( + + assert result == { + "status": "ERROR", + "error_details": ( + "history_data must be a BigQuery table ID. SQL query statements are" + " not supported by this tool." + ), + } + mock_execute_sql.assert_not_called() + + +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +def test_detect_anomalies_rejects_query_statement_target_data(mock_execute_sql): + mock_credentials = mock.MagicMock(spec=Credentials) + mock_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + mock_tool_context = mock.create_autospec(ToolContext, instance=True) + + result = query_tool.detect_anomalies( project_id="test-project", - query=expected_anomaly_detection_query, + history_data="test-dataset.history-table", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + target_data="SELECT * FROM `test-dataset.target-table`", credentials=mock_credentials, settings=mock_settings, tool_context=mock_tool_context, - caller_id="detect_anomalies", ) + assert result == { + "status": "ERROR", + "error_details": ( + "target_data must be a BigQuery table ID. SQL query statements are" + " not supported by this tool." + ), + } + mock_execute_sql.assert_not_called() + def test_detect_anomalies_with_invalid_id_cols(): """Test time series anomaly detection tool invocation with invalid times_series_id_cols.""" @@ -1679,6 +1861,32 @@ def test_detect_anomalies_with_invalid_id_cols(): ) +def test_detect_anomalies_with_invalid_id_field_path(): + """Test detect_anomalies rejects unsafe ORDER BY identifiers.""" + mock_credentials = mock.MagicMock(spec=Credentials) + mock_settings = BigQueryToolConfig() + mock_tool_context = mock.create_autospec(ToolContext, instance=True) + + result = query_tool.detect_anomalies( + project_id="test-project", + history_data="test-dataset.test-table", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + times_series_id_cols=["(SELECT password FROM admin_users LIMIT 1)"], + credentials=mock_credentials, + settings=mock_settings, + tool_context=mock_tool_context, + ) + + assert result == { + "status": "ERROR", + "error_details": ( + "All elements in times_series_id_cols must be valid BigQuery column" + " names or field paths." + ), + } + + @pytest.mark.parametrize( ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"), [ @@ -1801,7 +2009,7 @@ def test_execute_sql_user_job_labels_augment_internal_labels( pytest.param( lambda tool_context: query_tool.forecast( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", timestamp_col="ts_col", data_col="data_col", credentials=mock.create_autospec(Credentials, instance=True), @@ -1828,7 +2036,7 @@ def test_execute_sql_user_job_labels_augment_internal_labels( pytest.param( lambda tool_context: query_tool.detect_anomalies( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock.create_autospec(Credentials, instance=True), @@ -1867,7 +2075,7 @@ def test_ml_tool_job_labels(tool_call, expected_tool_label): pytest.param( lambda tool_context: query_tool.forecast( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", timestamp_col="ts_col", data_col="data_col", credentials=mock.create_autospec(Credentials, instance=True), @@ -1898,7 +2106,7 @@ def test_ml_tool_job_labels(tool_call, expected_tool_label): pytest.param( lambda tool_context: query_tool.detect_anomalies( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock.create_autospec(Credentials, instance=True), @@ -1942,7 +2150,7 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): pytest.param( lambda tool_context: query_tool.forecast( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", timestamp_col="ts_col", data_col="data_col", credentials=mock.create_autospec(Credentials, instance=True), @@ -1983,7 +2191,7 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): pytest.param( lambda tool_context: query_tool.detect_anomalies( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock.create_autospec(Credentials, instance=True), @@ -2122,7 +2330,7 @@ def test_execute_sql_maximum_bytes_billed_config(): pytest.param( lambda settings, tool_context: query_tool.forecast( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", timestamp_col="ts_col", data_col="data_col", credentials=mock.create_autospec(Credentials, instance=True), @@ -2147,7 +2355,7 @@ def test_execute_sql_maximum_bytes_billed_config(): pytest.param( lambda settings, tool_context: query_tool.detect_anomalies( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock.create_autospec(Credentials, instance=True), @@ -2206,7 +2414,7 @@ def test_tool_call_doesnt_change_global_settings(tool_call): pytest.param( lambda settings, tool_context: query_tool.forecast( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", timestamp_col="ts_col", data_col="data_col", credentials=mock.create_autospec(Credentials, instance=True), @@ -2231,7 +2439,7 @@ def test_tool_call_doesnt_change_global_settings(tool_call): pytest.param( lambda settings, tool_context: query_tool.detect_anomalies( project_id="test-project", - history_data="SELECT * FROM `test-dataset.test-table`", + history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", times_series_data_col="ts_data", credentials=mock.create_autospec(Credentials, instance=True),