diff --git a/google/cloud/aiplatform/vertex_ray/data.py b/google/cloud/aiplatform/vertex_ray/data.py index 217eb52106..09d894e4af 100644 --- a/google/cloud/aiplatform/vertex_ray/data.py +++ b/google/cloud/aiplatform/vertex_ray/data.py @@ -144,26 +144,30 @@ def write_bigquery( By default, concurrency is dynamically decided based on the available resources. """ - if ray.__version__ == "2.4.0": + version = ray.__version__ + + if version == "2.4.0": raise RuntimeError(_V2_4_WARNING_MESSAGE) - elif ray.__version__ in ("2.9.3", "2.33.0", "2.42.0", "2.47.1"): - if ray.__version__ == "2.9.3": + elif version in ("2.9.3", "2.33.0", "2.42.0", "2.47.1"): + if version == "2.9.3": warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1) - if ray_remote_args is None: - ray_remote_args = {} - - # Each write task will launch individual remote tasks to write each block - # To avoid duplicate block writes, the write task should not be retried - if ray_remote_args.get("max_retries", 0) != 0: - print( - "[Ray on Vertex AI]: The max_retries of a BigQuery Write " - "Task should be set to 0 to avoid duplicate writes." - ) + # Avoid dict modification if not needed; assignment needed only when input is None + ray_remote_args = {} if ray_remote_args is None else ray_remote_args + + max_retries = ray_remote_args.get("max_retries") + if max_retries is not None: + if max_retries != 0: + print( + "[Ray on Vertex AI]: The max_retries of a BigQuery Write " + "Task should be set to 0 to avoid duplicate writes." + ) else: + # Only assign if it wasn't present in the input mapping ray_remote_args["max_retries"] = 0 - if ray.__version__ == "2.9.3": + # Avoid lookups, also, no need to re-check version set membership + if version == "2.9.3": # Concurrency and overwrite_table are not supported in 2.9.3 datasink = _BigQueryDatasink( project_id=project_id, @@ -174,7 +178,7 @@ def write_bigquery( datasink=datasink, ray_remote_args=ray_remote_args, ) - elif ray.__version__ in ("2.33.0", "2.42.0", "2.47.1"): + else: datasink = _BigQueryDatasink( project_id=project_id, dataset=dataset, @@ -188,6 +192,6 @@ def write_bigquery( ) else: raise ImportError( - f"[Ray on Vertex AI]: Unsupported version {ray.__version__}." + f"[Ray on Vertex AI]: Unsupported version {version}." + "Only 2.47.1, 2.42.0, 2.33.0 and 2.9.3 are supported." )