feat: Added AutoMLForecastingTrainingJob and tests #237
Merged
Changes from all commits (24 commits)
989eac5  feat: Added AutoMLForecastingTrainingJob and tests (#1)
60b3113  Merge branch 'dev' of https://github.com/thehardikv/python-aiplatform… (hardik-vala)
68d7c19  Rename 'forecasting_task' -> 'automl_forecasting' (hardik-vala)
d9cc97f  Update google/cloud/aiplatform/training_jobs.py
bb6c78c  Update google/cloud/aiplatform/training_jobs.py
19654ae  Update google/cloud/aiplatform/training_jobs.py
83797bb  Make static_columns Optional + Fix lint problems (hardik-vala)
37a7949  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
ce0eb99  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
d66a978  Add _supported_training_schemas to AutoMLForecastingTrainingJob (hardik-vala)
ff97f60  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
9c8a893  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
0c3e5a9  Specify correct values for predefined_split_column_name doc (hardik-vala)
5058a3a  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
ec88e9d  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
8c4852b  Created TimeSeriesDataset (hardik-vala)
08de245  Fix forecasting training unit tests (hardik-vala)
78d6756  Rename uCAIP Forecasting training API fields to align with UI and API…
47fea20  Lint (hardik-vala)
5077d11  Merge branch 'dev' of https://github.com/thehardikv/python-aiplatform… (hardik-vala)
5436c52  Lint (hardik-vala)
0157cc5  Merge remote-tracking branch 'upstream/dev' into dev (hardik-vala)
cbfeae4  Update quantile loss objective doc (hardik-vala)
b0e42b3  Expose TimeSeriesDataset at root module level (hardik-vala)
google/cloud/aiplatform/__init__.py

```diff
@@ -33,6 +33,7 @@
     CustomContainerTrainingJob,
     CustomPythonPackageTrainingJob,
     AutoMLTabularTrainingJob,
+    AutoMLForecastingTrainingJob,
     AutoMLImageTrainingJob,
     AutoMLTextTrainingJob,
     AutoMLVideoTrainingJob,
@@ -52,6 +53,7 @@
     "init",
     "AutoMLImageTrainingJob",
     "AutoMLTabularTrainingJob",
+    "AutoMLForecastingTrainingJob",
     "AutoMLTextTrainingJob",
     "AutoMLVideoTrainingJob",
     "BatchPredictionJob",
@@ -63,5 +65,6 @@
     "Model",
     "TabularDataset",
     "TextDataset",
+    "TimeSeriesDataset",
     "VideoDataset",
 )
```

Review comment on the `__all__` hunk (at `"AutoMLTabularTrainingJob",`): TimeSeriesDataset should also be added here.
google/cloud/aiplatform/datasets/_datasources.py

```diff
@@ -224,6 +224,11 @@ def create_datasource(
             raise ValueError("tabular dataset does not support data import.")
         return TabularDatasource(gcs_source, bq_source)
 
+    if metadata_schema_uri == schema.dataset.metadata.time_series:
+        if import_schema_uri:
+            raise ValueError("time series dataset does not support data import.")
+        return TabularDatasource(gcs_source, bq_source)
+
     if not import_schema_uri and not gcs_source:
         return NonTabularDatasource()
     elif import_schema_uri and gcs_source:
```

Review comment on the new time series branch: LGTM
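To make the dispatch above concrete, here is a hedged sketch that calls the private helper directly (ordinary code would go through `TimeSeriesDataset.create` instead; the BigQuery table below is a placeholder): the time series branch reuses `TabularDatasource`, and an `import_schema_uri` is rejected.

```python
from google.cloud.aiplatform import schema
from google.cloud.aiplatform.datasets import _datasources

# The time series schema takes the same datasource path as tabular data.
datasource = _datasources.create_datasource(
    metadata_schema_uri=schema.dataset.metadata.time_series,
    bq_source="bq://my-project.my_dataset.sales_history",  # placeholder table
)

# Supplying an import schema raises, mirroring the tabular branch:
# _datasources.create_datasource(
#     metadata_schema_uri=schema.dataset.metadata.time_series,
#     import_schema_uri="gs://bucket/schema.yaml",  # would raise ValueError
# )
```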
google/cloud/aiplatform/datasets/time_series_dataset.py (new file, 134 additions, 0 deletions)
```python
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Sequence, Tuple, Union

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils


class TimeSeriesDataset(datasets._Dataset):
    """Managed time series dataset resource for AI Platform"""

    _supported_metadata_schema_uris: Optional[Tuple[str]] = (
        schema.dataset.metadata.time_series,
    )

    @classmethod
    def create(
        cls,
        display_name: str,
        gcs_source: Optional[Union[str, Sequence[str]]] = None,
        bq_source: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
        encryption_spec_key_name: Optional[str] = None,
        sync: bool = True,
    ) -> "TimeSeriesDataset":
        """Creates a new time series dataset.

        Args:
            display_name (str):
                Required. The user-defined name of the Dataset.
                The name can be up to 128 characters long and can consist
                of any UTF-8 characters.
            gcs_source (Union[str, Sequence[str]]):
                Google Cloud Storage URI(-s) to the
                input file(s). May contain wildcards. For more
                information on wildcards, see
                https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
                examples:
                    str: "gs://bucket/file.csv"
                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
            bq_source (str):
                BigQuery URI to the input table.
                example:
                    "bq://project.dataset.table_name"
            project (str):
                Project to upload this dataset to. Overrides project set in
                aiplatform.init.
            location (str):
                Location to upload this dataset to. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to upload this dataset. Overrides
                credentials set in aiplatform.init.
            request_metadata (Sequence[Tuple[str, str]]):
                Strings which should be sent along with the request as metadata.
            encryption_spec_key_name (Optional[str]):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the dataset. Has the
                form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this Dataset and all sub-resources of this Dataset
                will be secured by this key.

                Overrides encryption_spec_key_name set in aiplatform.init.
            sync (bool):
                Whether to execute this method synchronously. If False, this
                method will be executed in a concurrent Future and any
                downstream object will be immediately returned and synced when
                the Future has completed.

        Returns:
            time_series_dataset (TimeSeriesDataset):
                Instantiated representation of the managed time series dataset
                resource.
        """

        utils.validate_display_name(display_name)

        api_client = cls._instantiate_client(location=location, credentials=credentials)

        metadata_schema_uri = schema.dataset.metadata.time_series

        datasource = _datasources.create_datasource(
            metadata_schema_uri=metadata_schema_uri,
            gcs_source=gcs_source,
            bq_source=bq_source,
        )

        return cls._create_and_import(
            api_client=api_client,
            parent=initializer.global_config.common_location_path(
                project=project, location=location
            ),
            display_name=display_name,
            metadata_schema_uri=metadata_schema_uri,
            datasource=datasource,
            project=project or initializer.global_config.project,
            location=location or initializer.global_config.location,
            credentials=credentials or initializer.global_config.credentials,
            request_metadata=request_metadata,
            encryption_spec=initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name
            ),
            sync=sync,
        )

    def import_data(self):
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not support 'import_data'"
        )
```
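For reference, a minimal usage sketch of the new class; the project, region, and BigQuery table names are placeholders, and it assumes the root-level export added elsewhere in this PR:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

# With sync=True (the default) this blocks until the managed dataset
# resource exists; sync=False returns immediately and completes the
# creation in a concurrent Future.
dataset = aiplatform.TimeSeriesDataset.create(
    display_name="sales-forecasting-data",
    bq_source="bq://my-project.my_dataset.sales_history",
)

# Note: dataset.import_data(...) raises NotImplementedError by design.
```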