fix: Rollback BigQuery Datasource to use do_write() interface
PiperOrigin-RevId: 577245702
matthew29tang authored and Copybara-Service committed Oct 27, 2023
1 parent 1def3f6 commit dc1b82a
Showing 2 changed files with 72 additions and 53 deletions.
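
For orientation before the diff: the rollback swaps the datasource's write hook from Ray Data's newer `write(blocks, ctx)` method, which runs inside a single write task and returns a status value, back to the older `do_write(blocks, metadata, ray_remote_args, ...)` method, which schedules one Ray task per block and returns the resulting ObjectRefs. A minimal sketch of the two shapes, inferred from the diff below rather than copied from Ray's source (class names are illustrative):

```python
# Hedged sketch of the two Ray Data write hooks this commit toggles between.
# Signatures are inferred from the diff below, not copied from Ray's source.
from typing import Any, Dict, List, Optional


class NewerWriteInterface:
    def write(self, blocks, ctx) -> str:
        # Newer hook (removed here): runs inside the write task itself and
        # returns a status value once every block has been written.
        for block in blocks:
            pass  # write each block in-line
        return "ok"


class OlderWriteInterface:
    def do_write(
        self,
        blocks,            # List[ObjectRef[Block]]
        metadata,          # List[BlockMetadata]
        ray_remote_args: Optional[Dict[str, Any]],
        **write_args,
    ) -> List:
        # Older hook (restored here): schedule one remote task per block and
        # return the ObjectRefs so Ray can wait on them.
        write_tasks = []
        for block_ref, block_meta in zip(blocks, metadata):
            pass  # e.g. write_tasks.append(remote_write_fn.remote(block_ref, ...))
        return write_tasks
```
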
81 changes: 50 additions & 31 deletions google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py
@@ -21,16 +21,16 @@
import time
from typing import Any, Dict, List, Optional
import uuid
import pyarrow.parquet as pq

from google.api_core import client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as v1_client_info
from google.cloud import bigquery
from google.cloud import bigquery_storage
from google.cloud.aiplatform import initializer

from ray.data._internal.execution.interfaces import TaskContext
from google.cloud.bigquery_storage import types
import pyarrow.parquet as pq
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block
from ray.data.block import BlockAccessor
from ray.data.block import BlockMetadata
@@ -50,9 +50,6 @@
gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
)

MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


class _BigQueryDatasourceReader(Reader):
def __init__(
@@ -70,12 +67,12 @@ def __init__(

if query is not None and dataset is not None:
raise ValueError(
"[Ray on Vertex AI]: Query and dataset kwargs cannot both "
+ "be provided (must be mutually exclusive)."
"[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
)

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
def _read_single_partition(stream) -> Block:
# Executed by a worker node
def _read_single_partition(stream, kwargs) -> Block:
client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
reader = client.read_rows(stream.name)
return reader.to_arrow()
@@ -99,9 +96,9 @@ def _read_single_partition(stream) -> Block:

if parallelism == -1:
parallelism = None
requested_session = bigquery_storage.types.ReadSession(
requested_session = types.ReadSession(
table=table,
data_format=bigquery_storage.types.DataFormat.ARROW,
data_format=types.DataFormat.ARROW,
)
read_session = bqs_client.create_read_session(
parent=f"projects/{self._project_id}",
@@ -110,9 +107,9 @@ def _read_single_partition(stream) -> Block:
)

read_tasks = []
logging.info(f"Created streams: {len(read_session.streams)}")
print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
if len(read_session.streams) < parallelism:
logging.info(
print(
"[Ray on Vertex AI]: The number of streams created by the "
+ "BigQuery Storage Read API is less than the requested "
+ "parallelism due to the size of the dataset."
@@ -128,11 +125,15 @@ def _read_single_partition(stream) -> Block:
exec_stats=None,
)

# Create the read task and pass the no-arg wrapper and metadata in
read_task = ReadTask(
lambda stream=stream: [_read_single_partition(stream)],
metadata,
# Create a no-arg wrapper read function which returns a block
read_single_partition = (
lambda stream=stream, kwargs=self._kwargs: [ # noqa: F731
_read_single_partition(stream, kwargs)
]
)

# Create the read task and pass the wrapper and metadata in
read_task = ReadTask(read_single_partition, metadata)
read_tasks.append(read_task)

return read_tasks
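
The read path above uses the BigQuery Storage Read API: `get_read_tasks` opens a read session on the table and wraps each stream in a `ReadTask` whose no-arg function a worker later calls. Below is a standalone sketch of that session/stream pattern, with placeholder project and table names and without the Ray wrapping or the custom `client_info` user agent shown above:

```python
# Hedged sketch of the read-session/stream pattern used by get_read_tasks().
# Project and table names are placeholders; requires google-cloud-bigquery-storage.
from google.cloud import bigquery_storage
from google.cloud.bigquery_storage import types

project_id = "my-project"  # placeholder
table = "projects/my-project/datasets/my_dataset/tables/my_table"  # placeholder

bqs_client = bigquery_storage.BigQueryReadClient()
session = bqs_client.create_read_session(
    parent=f"projects/{project_id}",
    read_session=types.ReadSession(
        table=table, data_format=types.DataFormat.ARROW
    ),
    max_stream_count=4,  # upper bound on parallel streams
)

# One stream per read task; each stream's rows are converted to Arrow.
for stream in session.streams:
    arrow_table = bqs_client.read_rows(stream.name).to_arrow()
    print(stream.name, arrow_table.num_rows)
```
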
@@ -167,14 +168,18 @@ class BigQueryDatasource(Datasource):
def create_reader(self, **kwargs) -> Reader:
return _BigQueryDatasourceReader(**kwargs)

def write(
def do_write(
self,
blocks: List[ObjectRef[Block]],
ctx: TaskContext,
metadata: List[BlockMetadata],
ray_remote_args: Optional[Dict[str, Any]],
project_id: Optional[str] = None,
dataset: Optional[str] = None,
) -> WriteResult:
def _write_single_block(block: Block, project_id: str, dataset: str):
) -> List[ObjectRef[WriteResult]]:
def _write_single_block(
block: Block, metadata: BlockMetadata, project_id: str, dataset: str
):
print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
block = BlockAccessor.for_block(block).to_arrow()

client = bigquery.Client(project=project_id, client_info=bq_info)
@@ -187,7 +192,7 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
pq.write_table(block, fp, compression="SNAPPY")

retry_cnt = 0
while retry_cnt < MAX_RETRY_CNT:
while retry_cnt < 10:
with open(fp, "rb") as source_file:
job = client.load_table_from_file(
source_file, dataset, job_config=job_config
@@ -197,11 +202,12 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
logging.info(job.result())
break
except exceptions.Forbidden as e:
logging.info(
print(
"[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
)
logging.debug(e)
time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
time.sleep(11)
print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")

project_id = project_id or initializer.global_config.project

@@ -210,21 +216,34 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
"[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
)

if ray_remote_args is None:
ray_remote_args = {}

_write_single_block = cached_remote_fn(_write_single_block).options(
**ray_remote_args
)
write_tasks = []

# Set up datasets to write
client = bigquery.Client(project=project_id, client_info=bq_info)
dataset_id = dataset.split(".", 1)[0]
try:
client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
logging.info(f"[Ray on Vertex AI]: Created dataset {dataset_id}.")
print("[Ray on Vertex AI]: Created dataset", dataset_id)
except exceptions.Conflict:
logging.info(
f"[Ray on Vertex AI]: Dataset {dataset_id} already exists. "
+ "The table will be overwritten if it already exists."
print(
"[Ray on Vertex AI]: Dataset",
dataset_id,
"already exists. The table will be overwritten if it already exists.",
)

# Delete table if it already exists
client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)

for block in blocks:
_write_single_block(block, project_id, dataset)
return "ok"
print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
for i in range(len(blocks)):
write_task = _write_single_block.remote(
blocks[i], metadata[i], project_id, dataset
)
write_tasks.append(write_task)
return write_tasks
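
The restored `do_write()` fans blocks out with Ray's internal `cached_remote_fn` helper and `.options(**ray_remote_args)`. The sketch below illustrates the same per-block fan-out using only the public `@ray.remote` decorator; `upload_block` is a hypothetical stand-in for the `_write_single_block` closure above:

```python
# Hedged sketch of the per-block fan-out performed by the restored do_write():
# wrap a write function as a Ray task, launch one task per block, and return
# the ObjectRefs so the caller can wait on them with ray.get().
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def upload_block(block, block_metadata):
    # Hypothetical stand-in for _write_single_block: the real function converts
    # the block to Arrow, writes a temporary Parquet file, and loads it into
    # BigQuery with a bounded retry loop.
    return len(block)


blocks = [[1, 2], [3, 4, 5]]                   # stand-ins for ObjectRef[Block]
metadata = [{"num_rows": 2}, {"num_rows": 3}]  # stand-ins for BlockMetadata

write_tasks = [
    upload_block.remote(block, meta) for block, meta in zip(blocks, metadata)
]
print(ray.get(write_tasks))  # -> [2, 3]
```
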
44 changes: 22 additions & 22 deletions tests/unit/vertex_ray/test_bigquery.py
@@ -27,7 +27,6 @@
from google.cloud.bigquery import job
from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
import mock
import pyarrow as pa
import pytest
import ray

@@ -90,6 +89,7 @@ def bq_query_mock(query):
client_mock.query = bq_query_mock

monkeypatch.setattr(bigquery, "Client", client_mock)
client_mock.reset_mock()
return client_mock


@@ -108,6 +108,7 @@ def bqs_create_read_session(max_stream_count=0, **kwargs):
client_mock.create_read_session = bqs_create_read_session

monkeypatch.setattr(bigquery_storage, "BigQueryReadClient", client_mock)
client_mock.reset_mock()
return client_mock


@@ -258,16 +259,16 @@ def setup_method(self):
def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

def test_write(self):
def test_do_write(self, ray_remote_function_mock):
bq_ds = bigquery_datasource.BigQueryDatasource()
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
dataset=_TEST_BQ_DATASET,
)
assert status == "ok"
assert len(write_tasks_list) == 4

def test_do_write_initialized(self, ray_remote_function_mock):
"""If initialized, do_write doesn't need to specify project_id."""
@@ -276,22 +277,21 @@ def test_do_write_initialized(self, ray_remote_function_mock):
staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI,
)
bq_ds = bigquery_datasource.BigQueryDatasource()
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
dataset=_TEST_BQ_DATASET,
)
assert status == "ok"
assert len(write_tasks_list) == 4

def test_write_dataset_exists(self, ray_remote_function_mock):
def test_do_write_dataset_exists(self, ray_remote_function_mock):
bq_ds = bigquery_datasource.BigQueryDatasource()
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
)
assert status == "ok"
assert len(write_tasks_list) == 4
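
For reference, a typical end-to-end use of the datasource goes through the generic Ray Data entry points, which dispatch to `create_reader()`/`get_read_tasks()` on read and to `do_write()` on write in the Ray release this rollback targets (assumed). Keyword names follow the reader and writer shown in the diffs above; the project and dataset values are placeholders:

```python
# Hedged usage sketch: reading from and writing to BigQuery through the
# generic Ray Data entry points. Project and dataset names are placeholders.
import ray
from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
    BigQueryDatasource,
)

ds = ray.data.read_datasource(
    BigQueryDatasource(),
    project_id="my-project",         # placeholder
    dataset="my_dataset.my_table",   # placeholder
    parallelism=4,
)

ds.write_datasource(
    BigQueryDatasource(),
    project_id="my-project",         # placeholder
    dataset="my_dataset.my_output_table",
)
```
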
