Skip to content

Commit

Permalink
feat: Add tunable parameters for Model Garden model training to the "…
Browse files Browse the repository at this point in the history
…AutoMLImageTrainingJob" in SDK.

PiperOrigin-RevId: 538999389
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Jun 9, 2023
1 parent b4eba68 commit 50646be
Show file tree
Hide file tree
Showing 5 changed files with 834 additions and 374 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/constants/base.py
Expand Up @@ -79,12 +79,14 @@

MODEL_GARDEN_ICN_MODEL_TYPES = {
"EFFICIENTNET",
"MAXVIT",
"VIT",
"COCA",
}

MODEL_GARDEN_IOD_MODEL_TYPES = {
"SPINENET",
"YOLO",
}

# TODO(b/177079208): Use EPCL Enums for validating Model Types
Expand Down
14 changes: 14 additions & 0 deletions google/cloud/aiplatform/hyperparameter_tuning.py
Expand Up @@ -22,6 +22,20 @@

from google.cloud.aiplatform.compat.types import study as gca_study_compat

SEARCH_ALGORITHM_TO_PROTO_VALUE = {
"random": gca_study_compat.StudySpec.Algorithm.RANDOM_SEARCH,
"grid": gca_study_compat.StudySpec.Algorithm.GRID_SEARCH,
None: gca_study_compat.StudySpec.Algorithm.ALGORITHM_UNSPECIFIED,
}

MEASUREMENT_SELECTION_TO_PROTO_VALUE = {
"best": (gca_study_compat.StudySpec.MeasurementSelectionType.BEST_MEASUREMENT),
"last": (gca_study_compat.StudySpec.MeasurementSelectionType.LAST_MEASUREMENT),
None: (
gca_study_compat.StudySpec.MeasurementSelectionType.MEASUREMENT_SELECTION_TYPE_UNSPECIFIED
),
}

_SCALE_TYPE_MAP = {
"linear": gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE,
"log": gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LOG_SCALE,
Expand Down
18 changes: 4 additions & 14 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -1998,18 +1998,6 @@ def job_spec(self):
return self._gca_resource.job_spec


_SEARCH_ALGORITHM_TO_PROTO_VALUE = {
"random": gca_study_compat.StudySpec.Algorithm.RANDOM_SEARCH,
"grid": gca_study_compat.StudySpec.Algorithm.GRID_SEARCH,
None: gca_study_compat.StudySpec.Algorithm.ALGORITHM_UNSPECIFIED,
}

_MEASUREMENT_SELECTION_TO_PROTO_VALUE = {
"best": gca_study_compat.StudySpec.MeasurementSelectionType.BEST_MEASUREMENT,
"last": gca_study_compat.StudySpec.MeasurementSelectionType.LAST_MEASUREMENT,
}


class HyperparameterTuningJob(_RunnableJob):
"""Vertex AI Hyperparameter Tuning Job."""

Expand Down Expand Up @@ -2215,8 +2203,10 @@ def __init__(
study_spec = gca_study_compat.StudySpec(
metrics=metrics,
parameters=parameters,
algorithm=_SEARCH_ALGORITHM_TO_PROTO_VALUE[search_algorithm],
measurement_selection_type=_MEASUREMENT_SELECTION_TO_PROTO_VALUE[
algorithm=hyperparameter_tuning.SEARCH_ALGORITHM_TO_PROTO_VALUE[
search_algorithm
],
measurement_selection_type=hyperparameter_tuning.MEASUREMENT_SELECTION_TO_PROTO_VALUE[
measurement_selection
],
)
Expand Down

0 comments on commit 50646be

Please sign in to comment.