Skip to content

Commit

Permalink
fix: fix error when calling update_state() after ExperimentRun.list()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544159034
  • Loading branch information
sararob authored and Copybara-Service committed Jun 28, 2023
1 parent d6476d0 commit cb255ec
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/metadata/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ def update(
Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
"""
if not hasattr(self, "_threading_lock"):
self._threading_lock = threading.Lock()

with self._threading_lock:
gca_resource = deepcopy(self._gca_resource)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/aiplatform/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,23 @@ def get_artifact_mock():
yield get_artifact_mock


@pytest.fixture
def get_artifact_mock_with_metadata():
with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock:
get_artifact_mock.return_value = GapicArtifact(
name=_TEST_ARTIFACT_NAME,
display_name=_TEST_ARTIFACT_ID,
schema_title=constants.SYSTEM_METRICS,
schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS],
metadata={
google.cloud.aiplatform.metadata.constants._VERTEX_EXPERIMENT_TRACKING_LABEL: True,
constants.GCP_ARTIFACT_RESOURCE_NAME_KEY: test_constants.TensorboardConstants._TEST_TENSORBOARD_RUN_NAME,
constants._STATE_KEY: gca_execution.Execution.State.RUNNING,
},
)
yield get_artifact_mock


@pytest.fixture
def get_artifact_not_found_mock():
with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock:
Expand Down Expand Up @@ -2026,6 +2043,27 @@ def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock):
retry=base._DEFAULT_RETRY,
)

@pytest.mark.usefixtures(
"get_metadata_store_mock",
"get_experiment_mock",
"get_experiment_run_mock",
"get_context_mock",
"list_contexts_mock",
"list_executions_mock",
"get_artifact_mock_with_metadata",
"update_context_mock",
)
def test_update_experiment_run_after_list(
self,
):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

experiment_run_list = aiplatform.ExperimentRun.list(experiment=_TEST_EXPERIMENT)
experiment_run_list[0].update_state(gca_execution.Execution.State.FAILED)


class TestTensorboard:
def test_get_or_create_default_tb_with_existing_default(
Expand Down

0 comments on commit cb255ec

Please sign in to comment.