Skip to content

Commit ca47835

Browse files
tswast, google-labs-jules[bot], gemini-code-assist[bot]
authored
feat: add support for hparam_range and hparam_candidates to bigframes.bigquery.create_model (#16640)
Fixes internal issue b/501171054 🦕

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent fe5245b commit ca47835

5 files changed

Lines changed: 88 additions & 16 deletions

File tree

packages/bigframes/bigframes/bigquery/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@
8787
to_json,
8888
to_json_string,
8989
)
90-
from bigframes.bigquery._operations.mathematical import rand
90+
from bigframes.bigquery._operations.mathematical import (
91+
hparam_candidates,
92+
hparam_range,
93+
rand,
94+
)
9195
from bigframes.bigquery._operations.search import create_vector_index, vector_search
9296
from bigframes.bigquery._operations.sql import sql_scalar
9397
from bigframes.bigquery._operations.struct import struct
@@ -130,6 +134,8 @@
130134
to_json,
131135
to_json_string,
132136
# mathematical ops
137+
hparam_candidates,
138+
hparam_range,
133139
rand,
134140
# search ops
135141
create_vector_index,
@@ -187,6 +193,8 @@
187193
"to_json",
188194
"to_json_string",
189195
# mathematical ops
196+
"hparam_candidates",
197+
"hparam_range",
190198
"rand",
191199
# search ops
192200
"create_vector_index",

packages/bigframes/bigframes/bigquery/_operations/mathematical.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Sequence
18+
1719
import bigframes.core.col
1820
import bigframes.core.expression
1921
from bigframes import dtypes
@@ -51,3 +53,71 @@ def rand() -> bigframes.core.col.Expression:
5153
is_deterministic=False,
5254
)
5355
return bigframes.core.col.Expression(bigframes.core.expression.OpExpression(op, ()))
56+
57+
58+
def hparam_range(min: float, max: float) -> bigframes.core.col.Expression:
    """
    Defines the minimum and maximum bounds of the search space of continuous
    values for a hyperparameter.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> # Specify a range of values for a hyperparameter.
        >>> learn_rate = bbq.hparam_range(0.0001, 1.0)

    Args:
        min (float or int):
            The minimum bound of the search space.
        max (float or int):
            The maximum bound of the search space. Must not be less than
            ``min``.

    Returns:
        bigframes.pandas.api.typing.Expression:
            An expression that can be used in model options.

    Raises:
        ValueError:
            If ``min`` is greater than ``max``.
    """
    # NOTE: ``min`` and ``max`` shadow the builtins inside this function. The
    # names are part of the public keyword interface, so they are kept, but
    # the builtins must not be called from this body.
    if min > max:
        # Fail fast locally instead of emitting SQL with inverted bounds.
        raise ValueError(
            f"min ({min!r}) must not be greater than max ({max!r})."
        )

    min_expr = bigframes.core.expression.const(min)
    max_expr = bigframes.core.expression.const(max)

    # The two positional placeholders are filled with the constant bounds.
    op = ops.SqlScalarOp(
        _output_type=dtypes.FLOAT_DTYPE,
        sql_template="HPARAM_RANGE({0}, {1})",
        is_deterministic=True,
    )
    return bigframes.core.col.Expression(
        bigframes.core.expression.OpExpression(op, (min_expr, max_expr))
    )
91+
92+
93+
def hparam_candidates(
    candidates: Sequence[float | str],
) -> bigframes.core.col.Expression:
    """
    Specifies the set of discrete values for the hyperparameter.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> # Specify a set of values for a hyperparameter.
        >>> optimizer = bbq.hparam_candidates(['ADAGRAD', 'SGD', 'FTRL'])

    Args:
        candidates (Sequence[float | str]):
            The set of discrete values for the hyperparameter. Must be a
            non-empty sequence such as a list or tuple; a bare string is
            rejected.

    Returns:
        bigframes.pandas.api.typing.Expression:
            An expression that can be used in model options.

    Raises:
        TypeError:
            If ``candidates`` is a single ``str`` or ``bytes`` value rather
            than a sequence of values.
        ValueError:
            If ``candidates`` is empty.
    """
    # A bare string is itself a sequence of characters, so ``tuple("SGD")``
    # would silently become ('S', 'G', 'D'). Reject it explicitly.
    if isinstance(candidates, (str, bytes)):
        raise TypeError(
            "candidates must be a sequence of values, not a single string; "
            f"did you mean [{candidates!r}]?"
        )

    values = tuple(candidates)
    if not values:
        raise ValueError("candidates must contain at least one value.")

    candidates_expr = bigframes.core.expression.const(values)

    # NOTE(review): the output type is declared as STRING even when the
    # candidates are numeric -- presumably it is only a placeholder for use in
    # model options; confirm before relying on it elsewhere.
    op = ops.SqlScalarOp(
        _output_type=dtypes.STRING_DTYPE,
        sql_template="HPARAM_CANDIDATES({0})",
        is_deterministic=True,
    )
    return bigframes.core.col.Expression(
        bigframes.core.expression.OpExpression(op, (candidates_expr,))
    )

packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql

Lines changed: 0 additions & 3 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_model`
2+
OPTIONS(model_type = 'LINEAR_REG', learn_rate = HPARAM_RANGE(0.0001, 1.0), optimizer = HPARAM_CANDIDATES(['ADAGRAD', 'SGD']))
3+
AS SELECT * FROM t

packages/bigframes/tests/unit/core/sql/test_ml.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616

17+
import bigframes.bigquery as bbq
1718
import bigframes.core.col as col
1819
import bigframes.core.expression as ex
1920
import bigframes.core.sql.ml
@@ -101,24 +102,17 @@ def test_create_model_list_option(snapshot):
101102
snapshot.assert_match(sql, "create_model_list_option.sql")
102103

103104

104-
def test_create_model_expression_option(snapshot):
105-
# An expression that calls a function on a literal value
106-
# e.g. 0.1 * 10
107-
literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE)
108-
multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE)
109-
math_expr = col.Expression(
110-
ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr))
111-
)
112-
105+
def test_create_model_hparam_tuning(snapshot):
113106
sql = bigframes.core.sql.ml.create_model_ddl(
114107
model_name="my_model",
115108
options={
116-
"l2_reg": math_expr,
117-
"booster_type": "gbtree",
109+
"model_type": "LINEAR_REG",
110+
"learn_rate": bbq.hparam_range(0.0001, 1.0),
111+
"optimizer": bbq.hparam_candidates(["ADAGRAD", "SGD"]),
118112
},
119113
training_data="SELECT * FROM t",
120114
)
121-
snapshot.assert_match(sql, "create_model_expression_option.sql")
115+
snapshot.assert_match(sql, "create_model_hparam_tuning.sql")
122116

123117

124118
def test_evaluate_model_basic(snapshot):

0 commit comments

Comments
 (0)