feat: Implementation of run partition query (#1080)
* feat: Implementation of run partition query

* Comments incorporated

* Comments incorporated

* Comments incorporated
ankiaga committed Jan 24, 2024
1 parent ec87c08 commit f3b23b2
Showing 10 changed files with 388 additions and 28 deletions.
2 changes: 2 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -103,6 +103,8 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
return connection.run_partition(
parsed_statement.client_side_statement_params[0]
)
if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY:
return connection.run_partitioned_query(parsed_statement)


def _get_streamed_result_set(column_name, type_code, column_values):
27 changes: 17 additions & 10 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -35,6 +35,9 @@
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITIONED_QUERY = re.compile(
r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)


def parse_stmt(query):
@@ -53,25 +56,29 @@ def parse_stmt(query)
client_side_statement_params = []
if RE_COMMIT.match(query):
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
client_side_statement_type = ClientSideStatementType.BEGIN
if RE_ROLLBACK.match(query):
elif RE_ROLLBACK.match(query):
client_side_statement_type = ClientSideStatementType.ROLLBACK
if RE_SHOW_COMMIT_TIMESTAMP.match(query):
elif RE_SHOW_COMMIT_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
elif RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
elif RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
elif RE_BEGIN.match(query):
client_side_statement_type = ClientSideStatementType.BEGIN
elif RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
elif RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if RE_PARTITION_QUERY.match(query):
elif RE_RUN_PARTITIONED_QUERY.match(query):
match = re.search(RE_RUN_PARTITIONED_QUERY, query)
client_side_statement_params.append(match.group(4))
client_side_statement_type = ClientSideStatementType.RUN_PARTITIONED_QUERY
elif RE_PARTITION_QUERY.match(query):
match = re.search(RE_PARTITION_QUERY, query)
client_side_statement_params.append(match.group(2))
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
if RE_RUN_PARTITION.match(query):
elif RE_RUN_PARTITION.match(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
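For reference, a minimal sketch of what the new RE_RUN_PARTITIONED_QUERY pattern captures; the statement text below is illustrative, but group(4) is exactly what parse_stmt stores as the single client_side_statement_params entry:

import re

RE_RUN_PARTITIONED_QUERY = re.compile(
    r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)

match = RE_RUN_PARTITIONED_QUERY.match(
    "run partitioned query SELECT SingerId, FirstName FROM Singers"
)
# The fourth capture group is the inner SQL that is later handed to
# Connection.run_partitioned_query.
print(match.group(4))  # -> "SELECT SingerId, FirstName FROM Singers"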
40 changes: 28 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
@@ -511,15 +511,7 @@ def partition_query(
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[0]
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
)
if self.read_only is not True and self._client_transaction_started is True:
raise ProgrammingError(
"Partitioned query not supported as the connection is not in "
"read only mode or ReadWrite transaction started"
)
self._partitioned_query_validation(partitioned_query, statement)

batch_snapshot = self._database.batch_snapshot()
partition_ids = []
@@ -531,17 +523,18 @@
query_options=query_options,
)
)

batch_transaction_id = batch_snapshot.get_batch_transaction_id()
for partition in partitions:
batch_transaction_id = batch_snapshot.get_batch_transaction_id()
partition_ids.append(
partition_helper.encode_to_string(batch_transaction_id, partition)
)
return partition_ids

@check_not_closed
def run_partition(self, batch_transaction_id):
def run_partition(self, encoded_partition_id):
partition_id: PartitionId = partition_helper.decode_from_string(
batch_transaction_id
encoded_partition_id
)
batch_transaction_id = partition_id.batch_transaction_id
batch_snapshot = self._database.batch_snapshot(
@@ -551,6 +544,29 @@ def run_partition(self, batch_transaction_id):
)
return batch_snapshot.process(partition_id.partition_result)

@check_not_closed
def run_partitioned_query(
self,
parsed_statement: ParsedStatement,
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[0]
self._partitioned_query_validation(partitioned_query, statement)
batch_snapshot = self._database.batch_snapshot()
return batch_snapshot.run_partitioned_query(
partitioned_query, statement.params, statement.param_types
)

def _partitioned_query_validation(self, partitioned_query, statement):
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
)
if self.read_only is not True and self._client_transaction_started is True:
raise ProgrammingError(
"Partitioned query is not supported because the connection is in a read/write transaction."
)

def __enter__(self):
return self

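Taken together with the parser and executor changes above, the feature is driven from the DB API like any other client-side statement. A minimal usage sketch, assuming an open dbapi Connection `conn` in autocommit (or read-only) mode and a hypothetical Singers table:

cursor = conn.cursor()
cursor.execute("RUN PARTITIONED QUERY SELECT SingerId, FirstName FROM Singers")
# The cursor iterates the MergedResultSet returned by run_partitioned_query;
# rows arrive from all partitions, so their order is not guaranteed.
for row in cursor.fetchall():
    print(row)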
5 changes: 4 additions & 1 deletion google/cloud/spanner_dbapi/cursor.py
@@ -49,6 +49,7 @@
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
from google.cloud.spanner_v1.merged_result_set import MergedResultSet

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])

@@ -248,7 +249,9 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self, self._parsed_statement
)
if self._result_set is not None:
if isinstance(self._result_set, StreamedManyResultSets):
if isinstance(
self._result_set, StreamedManyResultSets
) or isinstance(self._result_set, MergedResultSet):
self._itr = self._result_set
else:
self._itr = PeekIterator(self._result_set)
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
@@ -35,6 +35,7 @@ class ClientSideStatementType(Enum):
ABORT_BATCH = 8
PARTITION_QUERY = 9
RUN_PARTITION = 10
RUN_PARTITIONED_QUERY = 11


@dataclass
72 changes: 67 additions & 5 deletions google/cloud/spanner_v1/database.py
@@ -54,6 +54,7 @@
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
from google.cloud.spanner_v1.session import Session
@@ -1416,11 +1417,6 @@ def generate_query_batches(
(Optional) desired size for each partition generated. The service
uses this as a hint, the actual partition size may differ.
:type partition_size_bytes: int
:param partition_size_bytes:
(Optional) desired size for each partition generated. The service
uses this as a hint, the actual partition size may differ.
:type max_partitions: int
:param max_partitions:
(Optional) desired maximum number of partitions generated. The
@@ -1513,6 +1509,72 @@ def process_query_batch(
partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
)

def run_partitioned_query(
self,
sql,
params=None,
param_types=None,
partition_size_bytes=None,
max_partitions=None,
query_options=None,
data_boost_enabled=False,
):
"""Start a partitioned query operation to get a list of partitions and
then execute each partition on a separate thread.
:type sql: str
:param sql: SQL query statement
:type params: dict, {str -> column value}
:param params: values for parameter replacement. Keys must match
the names used in ``sql``.
:type param_types: dict[str -> Union[dict, .types.Type]]
:param param_types:
(Optional) maps explicit types for one or more param values;
required if parameters are passed.
:type partition_size_bytes: int
:param partition_size_bytes:
(Optional) desired size for each partition generated. The service
uses this as a hint, the actual partition size may differ.
:type max_partitions: int
:param max_partitions:
(Optional) desired maximum number of partitions generated. The
service uses this as a hint, the actual number of partitions may
differ.
:type query_options:
:class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
or :class:`dict`
:param query_options:
(Optional) Query optimizer configuration to use for the given query.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`
:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned query and this field is
set to ``true``, the request will be executed using data boost.
Please see https://cloud.google.com/spanner/docs/databoost/databoost-overview
:rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
)
)
return MergedResultSet(self, partitions, 0)

def process(self, batch):
"""Process a single, partitioned query or read.
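The same helper can be exercised directly on a batch snapshot, which is what Connection.run_partitioned_query does above. A minimal sketch, assuming `database` is an existing spanner_v1 Database instance and an illustrative table name:

batch_snapshot = database.batch_snapshot()
merged = batch_snapshot.run_partitioned_query(
    "SELECT SingerId, FirstName FROM Singers"
)
# MergedResultSet yields rows as the per-partition worker threads produce
# them, so the overall ordering across partitions is not guaranteed.
for row in merged:
    print(row)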
133 changes: 133 additions & 0 deletions google/cloud/spanner_v1/merged_result_set.py
@@ -0,0 +1,133 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from typing import Any, TYPE_CHECKING
from threading import Lock, Event

if TYPE_CHECKING:
from google.cloud.spanner_v1.database import BatchSnapshot

QUEUE_SIZE_PER_WORKER = 32
MAX_PARALLELISM = 16


class PartitionExecutor:
"""
Executor that runs a single partition on a separate thread and inserts
the resulting rows into the shared queue
"""

def __init__(self, batch_snapshot, partition_id, merged_result_set):
self._batch_snapshot: BatchSnapshot = batch_snapshot
self._partition_id = partition_id
self._merged_result_set: MergedResultSet = merged_result_set
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
results = None
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
for row in results:
if self._merged_result_set._metadata is None:
self._set_metadata(results)
self._queue.put(PartitionExecutorResult(data=row))
# Special case: The result set did not return any rows.
# Push the metadata to the merged result set.
if self._merged_result_set._metadata is None:
self._set_metadata(results)
except Exception as ex:
if self._merged_result_set._metadata is None:
self._set_metadata(results, True)
self._queue.put(PartitionExecutorResult(exception=ex))
finally:
# Emit a special 'is_last' result to ensure that the MergedResultSet
# is not blocked on a queue that never receives any more results.
self._queue.put(PartitionExecutorResult(is_last=True))

def _set_metadata(self, results, is_exception=False):
self._merged_result_set.metadata_lock.acquire()
try:
if not is_exception:
self._merged_result_set._metadata = results.metadata
finally:
self._merged_result_set.metadata_lock.release()
self._merged_result_set.metadata_event.set()


@dataclass
class PartitionExecutorResult:
data: Any = None
exception: Exception = None
is_last: bool = False


class MergedResultSet:
"""
Executes multiple partitions on different threads and then combines the
results from multiple queries using a synchronized queue. The order of the
records in the MergedResultSet is not guaranteed.
"""

def __init__(self, batch_snapshot, partition_ids, max_parallelism):
self._exception = None
self._metadata = None
self.metadata_event = Event()
self.metadata_lock = Lock()

partition_ids_count = len(partition_ids)
self._finished_count_down_latch = partition_ids_count
parallelism = min(MAX_PARALLELISM, partition_ids_count)
if max_parallelism != 0:
parallelism = min(partition_ids_count, max_parallelism)
self._queue = Queue(maxsize=QUEUE_SIZE_PER_WORKER * parallelism)

partition_executors = []
for partition_id in partition_ids:
partition_executors.append(
PartitionExecutor(batch_snapshot, partition_id, self)
)
executor = ThreadPoolExecutor(max_workers=parallelism)
for partition_executor in partition_executors:
executor.submit(partition_executor.run)
executor.shutdown(False)

def __iter__(self):
return self

def __next__(self):
if self._exception is not None:
raise self._exception
while True:
partition_result = self._queue.get()
if partition_result.is_last:
self._finished_count_down_latch -= 1
if self._finished_count_down_latch == 0:
raise StopIteration
elif partition_result.exception is not None:
self._exception = partition_result.exception
raise self._exception
else:
return partition_result.data

@property
def metadata(self):
self.metadata_event.wait()
return self._metadata

@property
def stats(self):
# TODO: Implement
return None

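The hand-off between PartitionExecutor and MergedResultSet boils down to one is_last sentinel per worker plus a count-down in the consumer. A standalone sketch of that pattern, with hypothetical names and data that are not part of the library:

from queue import Queue
from threading import Thread

SENTINEL = object()

def worker(rows, out):
    # Producer: push every row, then a single end-of-partition marker.
    for row in rows:
        out.put(row)
    out.put(SENTINEL)

def merge(partitions):
    out = Queue()
    for part in partitions:
        Thread(target=worker, args=(part, out)).start()
    remaining = len(partitions)  # count-down latch, like _finished_count_down_latch
    while remaining:
        item = out.get()
        if item is SENTINEL:
            remaining -= 1
        else:
            yield item

print(sorted(merge([[1, 2], [3], [4, 5]])))  # -> [1, 2, 3, 4, 5]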