Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _log_job_state(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)

Expand Down Expand Up @@ -1490,7 +1490,7 @@ def iter_outputs(
if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
f"current state: {self._gca_resource.state}"
f"current state: {self._gca_resource.state.name}"
)

output_info = self._gca_resource.output_info
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _block_until_complete(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * multiplier, max_wait)
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _block_until_complete(self) -> None:
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * multiplier, max_wait)
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def _block_until_complete(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,21 @@ def test_cancel_mock_job(self, fake_job_cancel_mock):

fake_job_cancel_mock.assert_called_once_with(name=_TEST_JOB_RESOURCE_NAME)

@pytest.mark.usefixtures("fake_job_getter_mock")
def test_log_job_state_uses_symbolic_name(self):
    """Regression test (Python 3.11+): the state log line produced by
    ``_log_job_state`` must carry the enum's symbolic name, never the
    bare integer value.
    """
    job_under_test = self.FakeJob(job_name=_TEST_JOB_RESOURCE_NAME)

    # Hand-built resource: a running job with a known name and state.
    resource = mock.Mock()
    resource.name = _TEST_JOB_RESOURCE_NAME
    resource.state = gca_job_state_compat.JobState.JOB_STATE_RUNNING
    job_under_test._gca_resource = resource

    with mock.patch.object(jobs._LOGGER, "info") as info_spy:
        job_under_test._log_job_state()

    # The logger receives one fully %-formatted message string.
    message = info_spy.call_args[0][0]
    assert "JOB_STATE_RUNNING" in message
    assert "current state:\n3" not in message


@pytest.fixture
def get_batch_prediction_job_mock():
Expand Down Expand Up @@ -695,6 +710,21 @@ def test_batch_prediction_iter_dirs_while_running(self):
)
bp.iter_outputs()

@pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")
def test_batch_prediction_iter_dirs_while_running_error_uses_symbolic_state_name(
    self,
):
    """RuntimeError message must use symbolic state name, not integer (regression for Python 3.11+).

    The job is constructed *outside* the ``pytest.raises`` block so the
    test only passes when ``iter_outputs`` itself raises — a RuntimeError
    escaping the constructor would otherwise satisfy the context manager
    and mask a broken test setup.
    """
    bp = jobs.BatchPredictionJob(
        batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
    )

    with pytest.raises(RuntimeError) as exc_info:
        bp.iter_outputs()

    error_msg = str(exc_info.value)
    # With the fix, the message reads "... current state: JOB_STATE_RUNNING"
    # instead of the raw integer produced by Python 3.11's IntEnum str().
    assert "JOB_STATE_RUNNING" in error_msg
    assert "current state: 3" not in error_msg

@pytest.mark.usefixtures("get_batch_prediction_job_empty_output_mock")
def test_batch_prediction_iter_dirs_invalid_output_info(self):
"""
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from google.cloud.aiplatform import (
pipeline_job_schedules,
schedules as aiplatform_schedules,
)
from google.cloud.aiplatform.preview.pipelinejob import (
pipeline_jobs as preview_pipeline_jobs,
Expand Down Expand Up @@ -434,6 +435,47 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_block_until_complete_logs_symbolic_state_name(self):
    """State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
    # self.state is read once per wait-loop iteration: ACTIVE keeps the
    # loop running (and triggers a log line), COMPLETED terminates it.
    state_sequence = [
        gca_schedule.Schedule.State.ACTIVE,  # first loop check
        gca_schedule.Schedule.State.COMPLETED,  # second check exits loop
    ]
    # One-element list so the closure below can mutate the index.
    state_index = [0]

    def get_state():
        # Clamp at the last entry so any extra reads keep returning COMPLETED.
        s = state_sequence[state_index[0]]
        state_index[0] = min(state_index[0] + 1, len(state_sequence) - 1)
        return s

    mock_schedule = mock.Mock()
    # A PropertyMock must be attached to the mock's *type* to intercept
    # attribute access (per unittest.mock docs); each Mock instance gets its
    # own subclass, so this does not leak to other mocks.
    type(mock_schedule).state = mock.PropertyMock(side_effect=get_state)

    # Real proto resource so the `.state.name` formatting path is exercised.
    active_gca = gca_schedule.Schedule(
        name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
        state=gca_schedule.Schedule.State.ACTIVE,
    )
    mock_schedule._gca_resource = active_gca

    # Collects the (already %-formatted) messages passed to _LOGGER.info.
    logged_messages = []

    # time.time: first call sets previous_time=0; second gives 10 → triggers log (10 >= 5)
    # NOTE(review): assumes _block_until_complete calls time.time() at most
    # three times in this scenario — an extra call would raise StopIteration.
    time_vals = iter([0.0, 10.0, 20.0])
    with mock.patch("google.cloud.aiplatform.schedules.time.time", side_effect=time_vals), \
        mock.patch("google.cloud.aiplatform.schedules.time.sleep"), \
        mock.patch.object(
            aiplatform_schedules._LOGGER, "info",
            side_effect=lambda msg, *a, **kw: logged_messages.append(msg)
        ):
        # Call the unbound method with the hand-built mock standing in as self.
        aiplatform_schedules._Schedule._block_until_complete(mock_schedule)

    state_log = next(
        (m for m in logged_messages if "current state" in m), None
    )
    assert state_log is not None, "No 'current state' log message found"
    assert "ACTIVE" in state_log
    # Guard against the Python 3.11 IntEnum regression (State.ACTIVE == 1).
    assert "current state:\n1" not in state_log

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,45 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 0)
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 0)
def test_block_until_complete_logs_symbolic_state_name(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_pipeline_bucket_exists,
):
    """Regression test (Python 3.11+): the periodic state log emitted while
    waiting for a PipelineJob must show the symbolic enum name, never the
    bare integer value of the state.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    # Collects the (already %-formatted) messages passed to _LOGGER.info.
    captured = []

    def _capture(message, *args, **kwargs):
        captured.append(message)

    with patch.object(storage.Blob, "download_as_bytes") as mock_load:
        mock_load.return_value = _TEST_PIPELINE_SPEC_JSON.encode()
        with mock.patch.object(
            pipeline_jobs._LOGGER, "info", side_effect=_capture
        ):
            job = pipeline_jobs.PipelineJob(
                display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
                template_path=_TEST_TEMPLATE_PATH,
                job_id=_TEST_PIPELINE_JOB_ID,
            )
            job.run(sync=True, create_request_timeout=None)

    state_logs = [m for m in captured if "current state" in m]
    assert state_logs, "No 'current state' log message found"
    # Guard against the Python 3.11 IntEnum regression (RUNNING == 3).
    assert "PIPELINE_STATE_RUNNING" in state_logs[0]
    assert "current state:\n3" not in state_logs[0]

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,59 @@ def teardown_method(self):
pathlib.Path(self._local_script_file_name).unlink()
initializer.global_pool.shutdown(wait=True)

def test_block_until_complete_logs_symbolic_state_name(
    self, mock_model_service_get
):
    """State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
    aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)

    # Collects the (already %-formatted) messages passed to _LOGGER.info.
    logged_messages = []

    # Stub out pipeline create/get and script packaging so run() completes
    # locally, and zero the wait constants so the polling loop doesn't sleep.
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
    ) as mock_create, mock.patch.object(
        source_utils._TrainingScriptPythonPackager, "package_and_copy_to_gcs"
    ) as mock_pkg, mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
    ) as mock_get, mock.patch.object(
        training_jobs, "_LOG_WAIT_TIME", 0
    ), mock.patch.object(
        training_jobs, "_JOB_WAIT_TIME", 0
    ), mock.patch.object(
        training_jobs._LOGGER, "info", side_effect=lambda msg, *a, **kw: logged_messages.append(msg)
    ):
        mock_pkg.return_value = _TEST_OUTPUT_PYTHON_PACKAGE_PATH
        mock_create.return_value = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
            model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
        )
        # Two RUNNING polls guarantee at least one "current state" log line
        # is emitted before the pipeline flips to SUCCEEDED.
        _running = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
            training_task_inputs={},
        )
        _succeeded = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
            training_task_inputs={},
            model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
        )
        # Surplus SUCCEEDED responses absorb any additional get() calls made
        # after the wait loop exits (e.g. while resolving the uploaded model).
        mock_get.side_effect = [_running, _running] + [_succeeded] * 8
        job = training_jobs.CustomTrainingJob(
            display_name=_TEST_DISPLAY_NAME,
            script_path=self._local_script_file_name,
            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        )
        job.run(base_output_dir=_TEST_BASE_OUTPUT_DIR, sync=True)

    # The polling loop must have logged the RUNNING state by name, not as
    # the integer 3 (the Python 3.11 IntEnum str() regression).
    state_log = next(
        (m for m in logged_messages if "current state" in m), None
    )
    assert state_log is not None, "No 'current state' log message found"
    assert "PIPELINE_STATE_RUNNING" in state_log
    assert "current state:\n3" not in state_log

@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
Expand Down