Skip to content
Permalink
Browse files
feat: Update proto definitions for bigquery/v2 to support new proto fields for BQML. (#817)

PiperOrigin-RevId: 387137741

Source-Link: googleapis/googleapis@8962c92

Source-Link: googleapis/googleapis-gen@102f1b4
  • Loading branch information
gcf-owl-bot committed Jul 27, 2021
1 parent 3c1be14 commit fe7a902e8b3e723ace335c9b499aea6d180a025b
Showing with 107 additions and 9 deletions.
  1. +95 −9 google/cloud/bigquery_v2/types/model.py
  2. +12 −0 google/cloud/bigquery_v2/types/table_reference.py
@@ -96,6 +96,8 @@ class Model(proto.Message):
Output only. Label columns that were used to train this
model. The output of the model will have a `predicted_`
prefix to these columns.
best_trial_id (int):
The best trial_id across all training runs.
"""

class ModelType(proto.Enum):
@@ -113,6 +115,7 @@ class ModelType(proto.Enum):
ARIMA = 11
AUTOML_REGRESSOR = 12
AUTOML_CLASSIFIER = 13
ARIMA_PLUS = 19

class LossType(proto.Enum):
r"""Loss metric to evaluate model training performance."""
@@ -151,6 +154,7 @@ class DataFrequency(proto.Enum):
WEEKLY = 5
DAILY = 6
HOURLY = 7
PER_MINUTE = 8

class HolidayRegion(proto.Enum):
r"""Type of supported holiday regions for time series forecasting
@@ -285,7 +289,7 @@ class RegressionMetrics(proto.Message):
median_absolute_error (google.protobuf.wrappers_pb2.DoubleValue):
Median absolute error.
r_squared (google.protobuf.wrappers_pb2.DoubleValue):
R^2 score.
R^2 score. This corresponds to r2_score in ML.EVALUATE.
"""

mean_absolute_error = proto.Field(
@@ -528,7 +532,7 @@ class ClusteringMetrics(proto.Message):
Mean of squared distances between each sample
to its cluster centroid.
clusters (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster]):
[Beta] Information for all clusters.
Information for all clusters.
"""

class Cluster(proto.Message):
@@ -697,10 +701,29 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
Whether the ARIMA model is fitted with drift or not. It
is always false when d is not 1.
time_series_id (str):
The id to indicate different time series.
The time_series_id value for this time series. It will be
one of the unique values from the time_series_id_column
specified during ARIMA model training. Only present when
time_series_id_column training option was used.
time_series_ids (Sequence[str]):
The tuple of time_series_ids identifying this time series.
It will be one of the unique tuples of values present in the
time_series_id_columns specified during ARIMA model
training. Only present when time_series_id_columns training
option was used, and the order of values here is the same as
the order of time_series_id_columns.
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
Seasonal periods. Repeated because multiple
periods are supported for one time series.
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
If true, holiday_effect is a part of time series
decomposition result.
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, spikes_and_dips is a part of time series
decomposition result.
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, step_changes is a part of time series decomposition
result.
"""

non_seasonal_order = proto.Field(
@@ -711,9 +734,19 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
)
has_drift = proto.Field(proto.BOOL, number=3,)
time_series_id = proto.Field(proto.STRING, number=4,)
time_series_ids = proto.RepeatedField(proto.STRING, number=9,)
seasonal_periods = proto.RepeatedField(
proto.ENUM, number=5, enum="Model.SeasonalPeriod.SeasonalPeriodType",
)
has_holiday_effect = proto.Field(
proto.MESSAGE, number=6, message=wrappers_pb2.BoolValue,
)
has_spikes_and_dips = proto.Field(
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
)
has_step_changes = proto.Field(
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
)

non_seasonal_order = proto.RepeatedField(
proto.MESSAGE, number=1, message="Model.ArimaOrder",
@@ -901,7 +934,7 @@ class TrainingRun(proto.Message):
"""

class TrainingOptions(proto.Message):
r"""
r"""Options used in model training.
Attributes:
max_iterations (int):
The maximum number of iterations in training.
@@ -972,8 +1005,9 @@ class TrainingOptions(proto.Message):
num_clusters (int):
Number of clusters for clustering models.
model_uri (str):
[Beta] Google Cloud Storage URI from which the model was
imported. Only applicable for imported models.
Google Cloud Storage URI from which the model
was imported. Only applicable for imported
models.
optimization_strategy (google.cloud.bigquery_v2.types.Model.OptimizationStrategy):
Optimization strategy for training linear
regression models.
@@ -1030,8 +1064,11 @@ class TrainingOptions(proto.Message):
If a valid value is specified, then holiday
effects modeling is enabled.
time_series_id_column (str):
The id column that will be used to indicate
different time series to forecast in parallel.
The time series id column that was used
during ARIMA model training.
time_series_id_columns (Sequence[str]):
The time series id columns that were used
during ARIMA model training.
horizon (int):
The number of periods ahead that need to be
forecasted.
@@ -1042,6 +1079,15 @@ class TrainingOptions(proto.Message):
output feature name is A.b.
auto_arima_max_order (int):
The max value of non-seasonal p and q.
decompose_time_series (google.protobuf.wrappers_pb2.BoolValue):
If true, perform time series decomposition and
save the results.
clean_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, clean spikes and dips in the input
time series.
adjust_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, detect step changes and make data
adjustment in the input time series.
"""

max_iterations = proto.Field(proto.INT64, number=1,)
@@ -1120,9 +1166,19 @@ class TrainingOptions(proto.Message):
proto.ENUM, number=42, enum="Model.HolidayRegion",
)
time_series_id_column = proto.Field(proto.STRING, number=43,)
time_series_id_columns = proto.RepeatedField(proto.STRING, number=51,)
horizon = proto.Field(proto.INT64, number=44,)
preserve_input_structs = proto.Field(proto.BOOL, number=45,)
auto_arima_max_order = proto.Field(proto.INT64, number=46,)
decompose_time_series = proto.Field(
proto.MESSAGE, number=50, message=wrappers_pb2.BoolValue,
)
clean_spikes_and_dips = proto.Field(
proto.MESSAGE, number=52, message=wrappers_pb2.BoolValue,
)
adjust_step_changes = proto.Field(
proto.MESSAGE, number=53, message=wrappers_pb2.BoolValue,
)

class IterationResult(proto.Message):
r"""Information about a single iteration of the training run.
@@ -1218,10 +1274,29 @@ class ArimaModelInfo(proto.Message):
Whether the ARIMA model is fitted with drift or not.
It is always false when d is not 1.
time_series_id (str):
The id to indicate different time series.
The time_series_id value for this time series. It will be
one of the unique values from the time_series_id_column
specified during ARIMA model training. Only present when
time_series_id_column training option was used.
time_series_ids (Sequence[str]):
The tuple of time_series_ids identifying this time series.
It will be one of the unique tuples of values present in the
time_series_id_columns specified during ARIMA model
training. Only present when time_series_id_columns training
option was used, and the order of values here is the same as
the order of time_series_id_columns.
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
Seasonal periods. Repeated because multiple
periods are supported for one time series.
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
If true, holiday_effect is a part of time series
decomposition result.
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, spikes_and_dips is a part of time series
decomposition result.
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, step_changes is a part of time series decomposition
result.
"""

non_seasonal_order = proto.Field(
@@ -1237,11 +1312,21 @@ class ArimaModelInfo(proto.Message):
)
has_drift = proto.Field(proto.BOOL, number=4,)
time_series_id = proto.Field(proto.STRING, number=5,)
time_series_ids = proto.RepeatedField(proto.STRING, number=10,)
seasonal_periods = proto.RepeatedField(
proto.ENUM,
number=6,
enum="Model.SeasonalPeriod.SeasonalPeriodType",
)
has_holiday_effect = proto.Field(
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
)
has_spikes_and_dips = proto.Field(
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
)
has_step_changes = proto.Field(
proto.MESSAGE, number=9, message=wrappers_pb2.BoolValue,
)

arima_model_info = proto.RepeatedField(
proto.MESSAGE,
@@ -1319,6 +1404,7 @@ class ArimaModelInfo(proto.Message):
label_columns = proto.RepeatedField(
proto.MESSAGE, number=11, message=standard_sql.StandardSqlField,
)
best_trial_id = proto.Field(proto.INT64, number=19,)


class GetModelRequest(proto.Message):
@@ -36,11 +36,23 @@ class TableReference(proto.Message):
maximum length is 1,024 characters. Certain operations allow
suffixing of the table ID with a partition decorator, such
as ``sample_table$20190123``.
project_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the project_id field.
dataset_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the dataset_id field.
table_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the table_id field.
"""

project_id = proto.Field(proto.STRING, number=1,)
dataset_id = proto.Field(proto.STRING, number=2,)
table_id = proto.Field(proto.STRING, number=3,)
project_id_alternative = proto.RepeatedField(proto.STRING, number=4,)
dataset_id_alternative = proto.RepeatedField(proto.STRING, number=5,)
table_id_alternative = proto.RepeatedField(proto.STRING, number=6,)


__all__ = tuple(sorted(__protobuf__.manifest))

0 comments on commit fe7a902

Please sign in to comment.