Skip to content

Commit

Permalink
fix: Add deprecation warnings when using Ray v2.4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629468075
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Apr 30, 2024
1 parent 3f037a1 commit 3a36784
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 1 deletion.
6 changes: 5 additions & 1 deletion google/cloud/aiplatform/preview/vertex_ray/cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import time
from typing import Dict, List, Optional
import warnings

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
Expand Down Expand Up @@ -127,7 +128,10 @@ def create_ray_cluster(
logging.info(
"[Ray on Vertex]: No VPC network configured. It is required for client connection."
)

if ray_version == "2.4":
warnings.warn(
_gapic_utils._V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2
)
local_ray_verion = _validation_utils.get_local_ray_version()
if ray_version != local_ray_verion:
if custom_images is None and head_node_type.custom_image is None:
Expand Down
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/preview/vertex_ray/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ray.data
from ray.data.dataset import Dataset
from typing import Any, Dict, Optional
import warnings

from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
BigQueryDatasource,
Expand All @@ -30,6 +31,10 @@
except ImportError:
_BigQueryDatasink = None

from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
)


def read_bigquery(
project_id: Optional[str] = None,
Expand All @@ -56,6 +61,7 @@ def write_bigquery(
ray_remote_args: Dict[str, Any] = None,
) -> Any:
if ray.__version__ == "2.4.0":
warnings.warn(_V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
return ds.write_datasource(
BigQueryDatasource(),
project_id=project_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import ray.cloudpickle as cpickle
import tempfile
from typing import Optional, TYPE_CHECKING
import warnings

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
Expand All @@ -33,6 +34,9 @@
from google.cloud.aiplatform.preview.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
)


try:
Expand Down Expand Up @@ -123,6 +127,7 @@ def _get_estimator_from(

ray_version = ray.__version__
if ray_version == "2.4.0":
warnings.warn(_V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
raise ValueError(
"[Ray on Vertex AI]: arg checkpoint should be a"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import ray
from typing import Callable, Optional, Union, TYPE_CHECKING
import warnings

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
Expand All @@ -28,6 +29,9 @@
from google.cloud.aiplatform.preview.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
)


try:
Expand Down Expand Up @@ -141,6 +145,7 @@ def _get_tensorflow_model_from(
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
warnings.warn(_V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
if not isinstance(checkpoint, ray_tensorflow.TensorflowCheckpoint):
raise ValueError(
"[Ray on Vertex AI]: arg checkpoint should be a"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
import ray
from ray.air._internal.torch_utils import load_torch_model
import tempfile
from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
)
from google.cloud.aiplatform.utils import gcs_utils
from typing import Optional
import warnings


try:
Expand Down Expand Up @@ -61,6 +65,7 @@ def get_pytorch_model_from(
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
warnings.warn(_V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
if not isinstance(checkpoint, ray_torch.TorchCheckpoint):
raise ValueError(
"[Ray on Vertex AI]: arg checkpoint should be a"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ray
import tempfile
from typing import Optional, TYPE_CHECKING
import warnings

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
Expand All @@ -32,6 +33,9 @@
from google.cloud.aiplatform.preview.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
)


try:
Expand Down Expand Up @@ -133,6 +137,7 @@ def _get_xgboost_model_from(
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
warnings.warn(_V2_4_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
if not isinstance(checkpoint, ray_xgboost.XGBoostCheckpoint):
raise ValueError(
"[Ray on Vertex AI]: arg checkpoint should be a"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

SUPPORTED_RAY_VERSIONS = immutabledict({"2.4": "2.4.0", "2.9": "2.9.3"})
SUPPORTED_PY_VERSION = ["3.10"]
_V2_4_WARNING_MESSAGE = (
"After May 30, 2024, using Ray version = 2.4 will result in an error. "
"Please use Ray version = 2.9.3 (default) instead."
)

# Artifact Repository available regions.
_AVAILABLE_REGIONS = ["us", "europe", "asia"]
Expand Down

0 comments on commit 3a36784

Please sign in to comment.