Skip to content

Commit

Permalink
feat: Adds the temporal fusion transformer (TFT) forecasting job
Browse files Browse the repository at this point in the history
COPYBARA_INTEGRATE_REVIEW=#1817 from mikelawrence-google:mikealawrence-add-tft-model-support dde8ac0
PiperOrigin-RevId: 494251134
  • Loading branch information
Mlawrence95 authored and Copybara-Service committed Dec 9, 2022
1 parent 43468bd commit 99313e0
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 86 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
AutoMLTabularTrainingJob,
AutoMLForecastingTrainingJob,
SequenceToSequencePlusForecastingTrainingJob,
TemporalFusionTransformerForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
Expand Down Expand Up @@ -162,6 +163,7 @@
"TensorboardRun",
"TensorboardTimeSeries",
"TextDataset",
"TemporalFusionTransformerForecastingTrainingJob",
"TimeSeriesDataset",
"VideoDataset",
)
1 change: 1 addition & 0 deletions google/cloud/aiplatform/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class definition:
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml"
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5204,19 +5204,31 @@ class column_data_types:


class AutoMLForecastingTrainingJob(_ForecastingTrainingJob):
"""Class to train AutoML forecasting models."""

_model_type = "AutoML"
_training_task_definition = schema.training_job.definition.automl_forecasting
_supported_training_schemas = (schema.training_job.definition.automl_forecasting,)


class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
"""Class to train Sequence to Sequence (Seq2Seq) forecasting models."""

_model_type = "Seq2Seq"
_training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting
_supported_training_schemas = (
schema.training_job.definition.seq2seq_plus_forecasting,
)


class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob):
"""Class to train Temporal Fusion Transformer (TFT) forecasting models."""

_model_type = "TFT"
_training_task_definition = schema.training_job.definition.tft_forecasting
_supported_training_schemas = (schema.training_job.definition.tft_forecasting,)


class AutoMLImageTrainingJob(_TrainingJob):
_supported_training_schemas = (
schema.training_job.definition.automl_image_classification,
Expand Down
5 changes: 3 additions & 2 deletions tests/system/aiplatform/test_e2e_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
pytest.param(
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
marks=pytest.mark.skip(reason="Seq2Seq not yet released."),
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
marks=pytest.mark.skip(reason="TFT not yet released."),
),
],
)
Expand Down
102 changes: 18 additions & 84 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"

_FORECASTING_JOB_MODEL_TYPES = [
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
]


@pytest.fixture
def mock_pipeline_service_create():
Expand Down Expand Up @@ -293,13 +299,7 @@ def teardown_method(self):
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_service_create(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -401,13 +401,7 @@ def test_run_call_pipeline_service_create(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_service_create_with_timeout(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -496,13 +490,7 @@ def test_run_call_pipeline_service_create_with_timeout(
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.usefixtures("mock_pipeline_service_get")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -584,13 +572,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.usefixtures("mock_pipeline_service_get")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_if_set_additional_experiments(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -675,13 +657,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
"mock_model_service_get",
)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_called_twice_raises(
self,
mock_dataset_time_series,
Expand Down Expand Up @@ -762,13 +738,7 @@ def test_run_called_twice_raises(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_raises_if_pipeline_fails(
self,
mock_pipeline_service_create_and_get_with_fail,
Expand Down Expand Up @@ -823,13 +793,7 @@ def test_run_raises_if_pipeline_fails(
with pytest.raises(RuntimeError):
job.get_model()

@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_raises_before_run_is_called(
self,
mock_pipeline_service_create,
Expand All @@ -855,13 +819,7 @@ def test_raises_before_run_is_called(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_splits_fraction(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -960,13 +918,7 @@ def test_splits_fraction(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_splits_timestamp(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -1067,13 +1019,7 @@ def test_splits_timestamp(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_splits_predefined(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -1168,13 +1114,7 @@ def test_splits_predefined(
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_splits_default(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -1264,13 +1204,7 @@ def test_splits_default(
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.usefixtures("mock_pipeline_service_get")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"training_job",
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
],
)
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference(
self,
mock_pipeline_service_create,
Expand Down

0 comments on commit 99313e0

Please sign in to comment.