Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,63 @@ def generate_text(
return bpd.read_gbq_query(sql)
else:
return session.read_gbq_query(sql)


@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_embedding(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> dataframe.DataFrame:
    """
    Generates text embeddings using a BigQuery ML model.

    See the `BigQuery ML GENERATE_EMBEDDING function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
            The model to use for text embedding. A ``pandas.Series`` is
            presumably the model description returned by ``create_model``
            (TODO confirm against ``_get_model_name_and_session``).
        input_ (Union[pandas.DataFrame, bigframes.pandas.DataFrame, str]):
            The DataFrame or query to use for text embedding.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON
            column.
        task_type (str, optional):
            A STRING value that specifies the intended downstream
            application task. Supported values are:

            - ``RETRIEVAL_QUERY``
            - ``RETRIEVAL_DOCUMENT``
            - ``SEMANTIC_SIMILARITY``
            - ``CLASSIFICATION``
            - ``CLUSTERING``
            - ``QUESTION_ANSWERING``
            - ``FACT_VERIFICATION``
            - ``CODE_RETRIEVAL_QUERY``
        output_dimensionality (int, optional):
            An INT64 value that specifies the size of the output embedding.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text embedding.
    """
    import bigframes.pandas as bpd

    model_name, session = _get_model_name_and_session(model, input_)
    source_sql = _to_sql(input_)

    statement = bigframes.core.sql.ml.generate_embedding(
        model_name=model_name,
        table=source_sql,
        flatten_json_output=flatten_json_output,
        task_type=task_type,
        output_dimensionality=output_dimensionality,
    )

    # Use the session resolved from the arguments when there is one;
    # otherwise run the query through the module-level bigframes.pandas API.
    if session is not None:
        return session.read_gbq_query(statement)
    return bpd.read_gbq_query(statement)
2 changes: 2 additions & 0 deletions bigframes/bigquery/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
create_model,
evaluate,
explain_predict,
generate_embedding,
generate_text,
global_explain,
predict,
Expand All @@ -37,4 +38,5 @@
"global_explain",
"transform",
"generate_text",
"generate_embedding",
]
28 changes: 28 additions & 0 deletions bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,31 @@ def generate_text(
sql += _build_struct_sql(struct_options)
sql += ")\n"
return sql


def generate_embedding(
    model_name: str,
    table: str,
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> str:
    """Encode the ML.GENERATE_EMBEDDING statement.

    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding for reference.

    Args:
        model_name: Model identifier; it is quoted with
            ``googlesql.identifier`` before being placed in the statement.
        table: SQL text for the input rows; it is wrapped in parentheses
            as-is.
        flatten_json_output: Optional ``flatten_json_output`` struct option.
        task_type: Optional ``task_type`` struct option.
        output_dimensionality: Optional ``output_dimensionality`` struct
            option.

    Returns:
        The ``SELECT * FROM ML.GENERATE_EMBEDDING(...)`` statement, terminated
        by a newline. Options left as ``None`` are not forwarded to
        ``_build_struct_sql``.
    """
    # Insertion order matters: it fixes the order of the struct fields in the
    # generated SQL.
    candidates: Dict[str, Union[str, int, bool, None]] = {
        "flatten_json_output": flatten_json_output,
        "task_type": task_type,
        "output_dimensionality": output_dimensionality,
    }
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {name: value for name, value in candidates.items() if value is not None}

    head = f"SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {googlesql.identifier(model_name)}, ({table})"
    return "".join((head, _build_struct_sql(struct_options), ")\n"))
64 changes: 64 additions & 0 deletions tests/system/large/bigquery/test_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import bigframes.bigquery.ml as ml
import bigframes.pandas as bpd


@pytest.fixture(scope="session")
def embedding_model(bq_connection, dataset_id):
    """Create the embedding model used by the tests, once per test session.

    The model is named ``embedding_model`` inside ``dataset_id`` and is
    created with the ``gemini-embedding-001`` endpoint over ``bq_connection``.
    """
    return ml.create_model(
        model_name=f"{dataset_id}.embedding_model",
        options={"endpoint": "gemini-embedding-001"},
        connection_name=bq_connection,
    )


def test_generate_embedding(embedding_model):
    """A default call returns one row per input plus result/status columns."""
    prompts = ["What is BigQuery?", "What is BQML?"]
    df = bpd.DataFrame({"content": prompts})

    result = ml.generate_embedding(embedding_model, df)

    assert len(result) == 2
    for expected_column in (
        "ml_generate_embedding_result",
        "ml_generate_embedding_status",
    ):
        assert expected_column in result.columns


def test_generate_embedding_with_options(embedding_model):
    """Passing ``output_dimensionality=256`` yields 256-element embeddings."""
    prompts = ["What is BigQuery?", "What is BQML?"]
    df = bpd.DataFrame({"content": prompts})

    result = ml.generate_embedding(
        embedding_model,
        df,
        task_type="RETRIEVAL_DOCUMENT",
        output_dimensionality=256,
    )

    assert len(result) == 2
    for expected_column in (
        "ml_generate_embedding_result",
        "ml_generate_embedding_status",
    ):
        assert expected_column in result.columns

    # Download the embedding column and check the size of the first vector.
    vectors = result["ml_generate_embedding_result"].to_pandas()
    assert len(vectors[0]) == 256
29 changes: 29 additions & 0 deletions tests/unit/bigquery/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,32 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo
assert "['a', 'b'] AS stop_sequences" in generated_sql
assert "true AS ground_with_google_search" in generated_sql
assert "'TYPE' AS request_type" in generated_sql


@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_embedding_with_pandas_dataframe(
    read_pandas_mock, read_gbq_query_mock
):
    """A pandas input is uploaded once and all options reach the SQL text."""
    # The mocked upload reports this SQL as the query for the pandas frame.
    read_pandas_mock.return_value._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )
    local_df = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops.generate_embedding(
        MODEL_SERIES,
        input_=local_df,
        flatten_json_output=True,
        task_type="RETRIEVAL_DOCUMENT",
        output_dimensionality=256,
    )

    read_pandas_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once()

    # The generated statement is the first positional argument of the call.
    positional_args, _ = read_gbq_query_mock.call_args
    generated_sql = positional_args[0]
    expected_fragments = (
        "ML.GENERATE_EMBEDDING",
        f"MODEL `{MODEL_NAME}`",
        "(SELECT * FROM `pandas_df`)",
        "true AS flatten_json_output",
        "'RETRIEVAL_DOCUMENT' AS task_type",
        "256 AS output_dimensionality",
    )
    for fragment in expected_fragments:
        assert fragment in generated_sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(true AS flatten_json_output, 'RETRIEVAL_DOCUMENT' AS task_type, 256 AS output_dimensionality))
19 changes: 19 additions & 0 deletions tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,22 @@ def test_generate_text_model_with_options(snapshot):
request_type="TYPE",
)
snapshot.assert_match(sql, "generate_text_model_with_options.sql")


def test_generate_embedding_model_basic(snapshot):
    """Compare the statement rendered without options to its snapshot."""
    rendered = bigframes.core.sql.ml.generate_embedding(
        model_name="my_project.my_dataset.my_model",
        table="SELECT * FROM new_data",
    )

    snapshot.assert_match(rendered, "generate_embedding_model_basic.sql")


def test_generate_embedding_model_with_options(snapshot):
    """Compare the statement rendered with every option set to its snapshot."""
    struct_kwargs = {
        "flatten_json_output": True,
        "task_type": "RETRIEVAL_DOCUMENT",
        "output_dimensionality": 256,
    }
    rendered = bigframes.core.sql.ml.generate_embedding(
        model_name="my_project.my_dataset.my_model",
        table="SELECT * FROM new_data",
        **struct_kwargs,
    )

    snapshot.assert_match(rendered, "generate_embedding_model_with_options.sql")
Loading