google/cloud/aiplatform/vertex_ray/data.py (36 changes: 20 additions & 16 deletions)
@@ -144,26 +144,30 @@ def write_bigquery(
             By default, concurrency is dynamically decided based on the available
             resources.
     """
-    if ray.__version__ == "2.4.0":
+    version = ray.__version__
+
+    if version == "2.4.0":
         raise RuntimeError(_V2_4_WARNING_MESSAGE)

-    elif ray.__version__ in ("2.9.3", "2.33.0", "2.42.0", "2.47.1"):
-        if ray.__version__ == "2.9.3":
+    elif version in ("2.9.3", "2.33.0", "2.42.0", "2.47.1"):
+        if version == "2.9.3":
             warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
-        if ray_remote_args is None:
-            ray_remote_args = {}
-
-        # Each write task will launch individual remote tasks to write each block
-        # To avoid duplicate block writes, the write task should not be retried
-        if ray_remote_args.get("max_retries", 0) != 0:
-            print(
-                "[Ray on Vertex AI]: The max_retries of a BigQuery Write "
-                "Task should be set to 0 to avoid duplicate writes."
-            )
+        # Avoid dict modification if not needed; assignment needed only when input is None
+        ray_remote_args = {} if ray_remote_args is None else ray_remote_args
+
+        max_retries = ray_remote_args.get("max_retries")
+        if max_retries is not None:
+            if max_retries != 0:
+                print(
+                    "[Ray on Vertex AI]: The max_retries of a BigQuery Write "
+                    "Task should be set to 0 to avoid duplicate writes."
+                )
         else:
+            # Only assign if it wasn't present in the input mapping
             ray_remote_args["max_retries"] = 0

-        if ray.__version__ == "2.9.3":
+        # Avoid lookups, also, no need to re-check version set membership
+        if version == "2.9.3":
             # Concurrency and overwrite_table are not supported in 2.9.3
             datasink = _BigQueryDatasink(
                 project_id=project_id,
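A minimal runnable sketch of the max_retries policy this hunk implements, pulled out of the library so it can be exercised on its own; the helper name _normalize_remote_args is invented for illustration and is not part of data.py:

from typing import Optional


def _normalize_remote_args(ray_remote_args: Optional[dict]) -> dict:
    # Only replace the argument when the caller passed None.
    ray_remote_args = {} if ray_remote_args is None else ray_remote_args

    max_retries = ray_remote_args.get("max_retries")
    if max_retries is not None:
        # An explicit non-zero value is kept, but the caller is warned that
        # retried write tasks can produce duplicate block writes.
        if max_retries != 0:
            print(
                "[Ray on Vertex AI]: The max_retries of a BigQuery Write "
                "Task should be set to 0 to avoid duplicate writes."
            )
    else:
        # The key was absent, so default it to 0 (no retries).
        ray_remote_args["max_retries"] = 0
    return ray_remote_args


print(_normalize_remote_args(None))                # {'max_retries': 0}
print(_normalize_remote_args({"max_retries": 3}))  # prints the warning, keeps 3
print(_normalize_remote_args({"num_cpus": 1}))     # adds max_retries=0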
@@ -174,7 +178,7 @@ def write_bigquery(
                 datasink=datasink,
                 ray_remote_args=ray_remote_args,
             )
-        elif ray.__version__ in ("2.33.0", "2.42.0", "2.47.1"):
+        else:
             datasink = _BigQueryDatasink(
                 project_id=project_id,
                 dataset=dataset,
@@ -188,6 +192,6 @@ def write_bigquery(
             )
     else:
         raise ImportError(
-            f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
+            f"[Ray on Vertex AI]: Unsupported version {version}."
             + "Only 2.47.1, 2.42.0, 2.33.0 and 2.9.3 are supported."
         )
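For context, the control flow after this change reads ray.__version__ once into a local and branches on it: 2.4.0 raises, 2.9.3 warns and takes the legacy datasink path, the other supported versions take the newer path, and anything else raises ImportError. A compressed, runnable sketch of that dispatch, with an invented dispatch() helper, placeholder strings standing in for _V2_4_WARNING_MESSAGE and _V2_9_WARNING_MESSAGE, and plain labels standing in for the two datasink branches:

import warnings

_SUPPORTED_VERSIONS = ("2.9.3", "2.33.0", "2.42.0", "2.47.1")


def dispatch(version: str) -> str:
    # Read the version once instead of re-reading ray.__version__ and
    # re-checking set membership in every branch.
    if version == "2.4.0":
        raise RuntimeError("placeholder for _V2_4_WARNING_MESSAGE")
    elif version in _SUPPORTED_VERSIONS:
        if version == "2.9.3":
            warnings.warn(
                "placeholder for _V2_9_WARNING_MESSAGE",
                DeprecationWarning,
                stacklevel=1,
            )
            return "legacy datasink (no overwrite_table / concurrency)"
        # All remaining supported versions share the newer datasink path.
        return "datasink with overwrite_table and concurrency"
    else:
        raise ImportError(
            f"[Ray on Vertex AI]: Unsupported version {version}."
            + "Only 2.47.1, 2.42.0, 2.33.0 and 2.9.3 are supported."
        )


print(dispatch("2.47.1"))  # datasink with overwrite_table and concurrency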