feat: Implementation for partitioned query in dbapi (#1067)
* feat: Implementation for partitioned query in dbapi

* Comments incorporated and added more tests

* Small fix

* Test fix

* Removing ClientSideStatementParamKey enum

* Comments incorporated
ankiaga committed Jan 10, 2024
1 parent c4210b2 commit 63daa8a
Showing 11 changed files with 324 additions and 34 deletions.
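
For context, a minimal end-to-end sketch of what this commit enables at the dbapi level, assuming an existing Spanner database; the instance, database, and table names below are placeholders, not part of this commit:

from google.cloud.spanner_dbapi import connect

# Partitioned queries require a read-only context.
connection = connect("my-instance", "my-database")
connection.read_only = True

cursor = connection.cursor()
# PARTITION returns one opaque token per partition as a single-column result set.
cursor.execute("PARTITION SELECT id, name FROM singers")
partition_ids = [row[0] for row in cursor.fetchall()]

for partition_id in partition_ids:
    # Each token is self-contained and can be run on any connection.
    cursor.execute("RUN PARTITION " + partition_id)
    print(cursor.fetchall())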
43 changes: 30 additions & 13 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -50,6 +50,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
     :param parsed_statement: parsed_statement based on the sql query
     """
     connection = cursor.connection
+    column_values = []
     if connection.is_closed:
         raise ProgrammingError(CONNECTION_CLOSED_ERROR)
     statement_type = parsed_statement.client_side_statement_type
@@ -63,24 +64,26 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
         connection.rollback()
         return None
     if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
-        if connection._transaction is None:
-            committed_timestamp = None
-        else:
-            committed_timestamp = connection._transaction.committed
+        if (
+            connection._transaction is not None
+            and connection._transaction.committed is not None
+        ):
+            column_values.append(connection._transaction.committed)
         return _get_streamed_result_set(
             ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
             TypeCode.TIMESTAMP,
-            committed_timestamp,
+            column_values,
         )
     if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
-        if connection._snapshot is None:
-            read_timestamp = None
-        else:
-            read_timestamp = connection._snapshot._transaction_read_timestamp
+        if (
+            connection._snapshot is not None
+            and connection._snapshot._transaction_read_timestamp is not None
+        ):
+            column_values.append(connection._snapshot._transaction_read_timestamp)
         return _get_streamed_result_set(
             ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
             TypeCode.TIMESTAMP,
-            read_timestamp,
+            column_values,
         )
     if statement_type == ClientSideStatementType.START_BATCH_DML:
         connection.start_batch_dml(cursor)
@@ -89,14 +92,28 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
         return connection.run_batch()
     if statement_type == ClientSideStatementType.ABORT_BATCH:
         return connection.abort_batch()
+    if statement_type == ClientSideStatementType.PARTITION_QUERY:
+        partition_ids = connection.partition_query(parsed_statement)
+        return _get_streamed_result_set(
+            "PARTITION",
+            TypeCode.STRING,
+            partition_ids,
+        )
+    if statement_type == ClientSideStatementType.RUN_PARTITION:
+        return connection.run_partition(
+            parsed_statement.client_side_statement_params[0]
+        )


-def _get_streamed_result_set(column_name, type_code, column_value):
+def _get_streamed_result_set(column_name, type_code, column_values):
     struct_type_pb = StructType(
         fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
     )

     result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
-    if column_value is not None:
-        result_set.values.extend([_make_value_pb(column_value)])
+    if len(column_values) > 0:
+        column_values_pb = []
+        for column_value in column_values:
+            column_values_pb.append(_make_value_pb(column_value))
+        result_set.values.extend(column_values_pb)
     return StreamedResultSet(iter([result_set]))
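
Illustratively, the executor surfaces partition tokens through the same streamed-result-set path as the timestamp statements; a sketch using the module's own helper, with made-up tokens:

from google.cloud.spanner_v1 import TypeCode

# Each token becomes one STRING row in a single-column result set named
# "PARTITION", so the cursor's normal fetch* machinery works unchanged.
result_set = _get_streamed_result_set("PARTITION", TypeCode.STRING, ["token-1", "token-2"])
for row in result_set:
    print(row[0])  # an opaque, base64-encoded partition token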
16 changes: 15 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -33,6 +33,8 @@
 RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
 RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
 RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
+RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
+RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)


 def parse_stmt(query):
@@ -48,6 +50,7 @@ def parse_stmt(query):
     :returns: ParsedStatement object.
     """
     client_side_statement_type = None
+    client_side_statement_params = []
     if RE_COMMIT.match(query):
         client_side_statement_type = ClientSideStatementType.COMMIT
     if RE_BEGIN.match(query):
@@ -64,8 +67,19 @@ def parse_stmt(query):
         client_side_statement_type = ClientSideStatementType.RUN_BATCH
     if RE_ABORT_BATCH.match(query):
         client_side_statement_type = ClientSideStatementType.ABORT_BATCH
+    if RE_PARTITION_QUERY.match(query):
+        match = re.search(RE_PARTITION_QUERY, query)
+        client_side_statement_params.append(match.group(2))
+        client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
+    if RE_RUN_PARTITION.match(query):
+        match = re.search(RE_RUN_PARTITION, query)
+        client_side_statement_params.append(match.group(3))
+        client_side_statement_type = ClientSideStatementType.RUN_PARTITION
     if client_side_statement_type is not None:
         return ParsedStatement(
-            StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
+            StatementType.CLIENT_SIDE,
+            Statement(query),
+            client_side_statement_type,
+            client_side_statement_params,
         )
     return None
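
A quick sanity check of the new grammar (sketch; the token string is made up): RE_PARTITION_QUERY captures everything after PARTITION as the inner query, and RE_RUN_PARTITION captures everything after RUN PARTITION as the token.

from google.cloud.spanner_dbapi.client_side_statement_parser import parse_stmt
from google.cloud.spanner_dbapi.parsed_statement import ClientSideStatementType

parsed = parse_stmt("PARTITION SELECT * FROM singers")
assert parsed.client_side_statement_type == ClientSideStatementType.PARTITION_QUERY
assert parsed.client_side_statement_params == ["SELECT * FROM singers"]

parsed = parse_stmt("RUN PARTITION some-opaque-token")
assert parsed.client_side_statement_type == ClientSideStatementType.RUN_PARTITION
assert parsed.client_side_statement_params == ["some-opaque-token"]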
57 changes: 56 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
@@ -19,8 +19,15 @@
 from google.api_core.exceptions import Aborted
 from google.api_core.gapic_v1.client_info import ClientInfo
 from google.cloud import spanner_v1 as spanner
+from google.cloud.spanner_dbapi import partition_helper
 from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
-from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
+from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
+from google.cloud.spanner_dbapi.parsed_statement import (
+    ParsedStatement,
+    Statement,
+    StatementType,
+)
+from google.cloud.spanner_dbapi.partition_helper import PartitionId
 from google.cloud.spanner_v1 import RequestOptions
 from google.cloud.spanner_v1.session import _get_retry_delay
 from google.cloud.spanner_v1.snapshot import Snapshot
@@ -585,6 +592,54 @@ def abort_batch(self):
         self._batch_dml_executor = None
         self._batch_mode = BatchMode.NONE

+    @check_not_closed
+    def partition_query(
+        self,
+        parsed_statement: ParsedStatement,
+        query_options=None,
+    ):
+        statement = parsed_statement.statement
+        partitioned_query = parsed_statement.client_side_statement_params[0]
+        if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
+            raise ProgrammingError(
+                "Only queries can be partitioned. Invalid statement: " + statement.sql
+            )
+        if self.read_only is not True and self._client_transaction_started is True:
+            raise ProgrammingError(
+                "Partitioned query not supported as the connection is not in "
+                "read only mode or ReadWrite transaction started"
+            )
+
+        batch_snapshot = self._database.batch_snapshot()
+        partition_ids = []
+        partitions = list(
+            batch_snapshot.generate_query_batches(
+                partitioned_query,
+                statement.params,
+                statement.param_types,
+                query_options=query_options,
+            )
+        )
+        for partition in partitions:
+            batch_transaction_id = batch_snapshot.get_batch_transaction_id()
+            partition_ids.append(
+                partition_helper.encode_to_string(batch_transaction_id, partition)
+            )
+        return partition_ids
+
+    @check_not_closed
+    def run_partition(self, batch_transaction_id):
+        partition_id: PartitionId = partition_helper.decode_from_string(
+            batch_transaction_id
+        )
+        batch_transaction_id = partition_id.batch_transaction_id
+        batch_snapshot = self._database.batch_snapshot(
+            read_timestamp=batch_transaction_id.read_timestamp,
+            session_id=batch_transaction_id.session_id,
+            transaction_id=batch_transaction_id.transaction_id,
+        )
+        return batch_snapshot.process(partition_id.partition_result)
+
     def __enter__(self):
         return self
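
Driving the two new methods directly rather than through SQL strings might look like this (connection setup assumed; the query is a placeholder):

from google.cloud.spanner_dbapi.client_side_statement_parser import parse_stmt

connection.read_only = True
parsed = parse_stmt("PARTITION SELECT id FROM singers")
partition_ids = connection.partition_query(parsed)

for partition_id in partition_ids:
    # Tokens carry the batch transaction id, so this call could equally be
    # made from another connection or process against the same database.
    for row in connection.run_partition(partition_id):
        print(row)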
16 changes: 10 additions & 6 deletions google/cloud/spanner_dbapi/parse_utils.py
@@ -232,19 +232,23 @@ def classify_statement(query, args=None):
         get_param_types(args or None),
         ResultsChecksum(),
     )
-    if RE_DDL.match(query):
-        return ParsedStatement(StatementType.DDL, statement)
+    statement_type = _get_statement_type(statement)
+    return ParsedStatement(statement_type, statement)

-    if RE_IS_INSERT.match(query):
-        return ParsedStatement(StatementType.INSERT, statement)

+def _get_statement_type(statement):
+    query = statement.sql
+    if RE_DDL.match(query):
+        return StatementType.DDL
+    if RE_IS_INSERT.match(query):
+        return StatementType.INSERT
     if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
         # As of 13-March-2020, Cloud Spanner only supports WITH for DQL
         # statements and doesn't yet support WITH for DML statements.
-        return ParsedStatement(StatementType.QUERY, statement)
+        return StatementType.QUERY

     statement.sql = ensure_where_clause(query)
-    return ParsedStatement(StatementType.UPDATE, statement)
+    return StatementType.UPDATE


 def sql_pyformat_args_to_spanner(sql, params):
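
The extracted helper keeps classify_statement's behavior and is what partition_query() above uses to reject non-queries; roughly, assuming Statement accepts the SQL as its first argument:

from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import Statement, StatementType

assert _get_statement_type(Statement("SELECT 1")) == StatementType.QUERY
assert _get_statement_type(Statement("CREATE TABLE t (id INT64) PRIMARY KEY (id)")) == StatementType.DDL
# Note: for UPDATE/DELETE statements the helper also rewrites statement.sql
# via ensure_where_clause(), so it mutates its argument.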
7 changes: 5 additions & 2 deletions google/cloud/spanner_dbapi/parsed_statement.py
@@ -1,4 +1,4 @@
-# Copyright 20203 Google LLC All rights reserved.
+# Copyright 2023 Google LLC All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any
+from typing import Any, List

from google.cloud.spanner_dbapi.checksum import ResultsChecksum

@@ -35,6 +35,8 @@ class ClientSideStatementType(Enum):
     START_BATCH_DML = 6
     RUN_BATCH = 7
     ABORT_BATCH = 8
+    PARTITION_QUERY = 9
+    RUN_PARTITION = 10


@dataclass
@@ -53,3 +55,4 @@ class ParsedStatement:
     statement_type: StatementType
     statement: Statement
     client_side_statement_type: ClientSideStatementType = None
+    client_side_statement_params: List[Any] = None
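
After this change, a parsed PARTITION statement carries its inner query in the new field; for example (sketch using the names defined in this module):

parsed = ParsedStatement(
    statement_type=StatementType.CLIENT_SIDE,
    statement=Statement("PARTITION SELECT id FROM singers"),
    client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
    client_side_statement_params=["SELECT id FROM singers"],
)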
46 changes: 46 additions & 0 deletions google/cloud/spanner_dbapi/partition_helper.py
@@ -0,0 +1,46 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any

import gzip
import pickle
import base64


def decode_from_string(encoded_partition_id):
    gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
    partition_id_bytes = gzip.decompress(gzip_bytes)
    return pickle.loads(partition_id_bytes)


def encode_to_string(batch_transaction_id, partition_result):
    partition_id = PartitionId(batch_transaction_id, partition_result)
    partition_id_bytes = pickle.dumps(partition_id)
    gzip_bytes = gzip.compress(partition_id_bytes)
    return str(base64.b64encode(gzip_bytes), "utf-8")


@dataclass
class BatchTransactionId:
    transaction_id: str
    session_id: str
    read_timestamp: Any


@dataclass
class PartitionId:
    batch_transaction_id: BatchTransactionId
    partition_result: Any
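
A round-trip sketch of the helpers above (all values made up); since the token is a pickle, it should only be decoded by trusted clients running a compatible library version:

import datetime

batch_transaction_id = BatchTransactionId(
    transaction_id="txn-123",
    session_id="session-456",
    read_timestamp=datetime.datetime(2024, 1, 10),
)
token = encode_to_string(batch_transaction_id, partition_result=b"opaque-partition")
decoded = decode_from_string(token)
assert decoded.batch_transaction_id.session_id == "session-456"
assert decoded.partition_result == b"opaque-partition"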
