diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py
index 06d0d25948..4d3408218c 100644
--- a/google/cloud/spanner_dbapi/client_side_statement_executor.py
+++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -50,6 +50,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
     :param parsed_statement: parsed_statement based on the sql query
     """
     connection = cursor.connection
+    column_values = []
     if connection.is_closed:
         raise ProgrammingError(CONNECTION_CLOSED_ERROR)
     statement_type = parsed_statement.client_side_statement_type
@@ -63,24 +64,26 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
         connection.rollback()
         return None
     if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
-        if connection._transaction is None:
-            committed_timestamp = None
-        else:
-            committed_timestamp = connection._transaction.committed
+        if (
+            connection._transaction is not None
+            and connection._transaction.committed is not None
+        ):
+            column_values.append(connection._transaction.committed)
         return _get_streamed_result_set(
             ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
             TypeCode.TIMESTAMP,
-            committed_timestamp,
+            column_values,
         )
     if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
-        if connection._snapshot is None:
-            read_timestamp = None
-        else:
-            read_timestamp = connection._snapshot._transaction_read_timestamp
+        if (
+            connection._snapshot is not None
+            and connection._snapshot._transaction_read_timestamp is not None
+        ):
+            column_values.append(connection._snapshot._transaction_read_timestamp)
         return _get_streamed_result_set(
             ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
             TypeCode.TIMESTAMP,
-            read_timestamp,
+            column_values,
         )
     if statement_type == ClientSideStatementType.START_BATCH_DML:
         connection.start_batch_dml(cursor)
@@ -89,14 +92,28 @@
         return connection.run_batch()
     if statement_type == ClientSideStatementType.ABORT_BATCH:
         return connection.abort_batch()
+    if statement_type == ClientSideStatementType.PARTITION_QUERY:
+        partition_ids = connection.partition_query(parsed_statement)
+        return _get_streamed_result_set(
+            "PARTITION",
+            TypeCode.STRING,
+            partition_ids,
+        )
+    if statement_type == ClientSideStatementType.RUN_PARTITION:
+        return connection.run_partition(
+            parsed_statement.client_side_statement_params[0]
+        )


-def _get_streamed_result_set(column_name, type_code, column_value):
+def _get_streamed_result_set(column_name, type_code, column_values):
     struct_type_pb = StructType(
         fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
     )

     result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
-    if column_value is not None:
-        result_set.values.extend([_make_value_pb(column_value)])
+    if len(column_values) > 0:
+        column_values_pb = []
+        for column_value in column_values:
+            column_values_pb.append(_make_value_pb(column_value))
+        result_set.values.extend(column_values_pb)
     return StreamedResultSet(iter([result_set]))
diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py
index 39970259b2..04a3cc523c 100644
--- a/google/cloud/spanner_dbapi/client_side_statement_parser.py
+++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -33,6 +33,8 @@
 RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
 RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
 RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
+RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
+RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)


 def parse_stmt(query):
@@ -48,6 +50,7 @@ def parse_stmt(query):
     :returns: ParsedStatement object.
     """
     client_side_statement_type = None
+    client_side_statement_params = []
     if RE_COMMIT.match(query):
         client_side_statement_type = ClientSideStatementType.COMMIT
     if RE_BEGIN.match(query):
@@ -64,8 +67,19 @@ def parse_stmt(query):
         client_side_statement_type = ClientSideStatementType.RUN_BATCH
     if RE_ABORT_BATCH.match(query):
         client_side_statement_type = ClientSideStatementType.ABORT_BATCH
+    if RE_PARTITION_QUERY.match(query):
+        match = re.search(RE_PARTITION_QUERY, query)
+        client_side_statement_params.append(match.group(2))
+        client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
+    if RE_RUN_PARTITION.match(query):
+        match = re.search(RE_RUN_PARTITION, query)
+        client_side_statement_params.append(match.group(3))
+        client_side_statement_type = ClientSideStatementType.RUN_PARTITION
     if client_side_statement_type is not None:
         return ParsedStatement(
-            StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
+            StatementType.CLIENT_SIDE,
+            Statement(query),
+            client_side_statement_type,
+            client_side_statement_params,
        )
     return None
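Note on the two new regexes above: group 2 of RE_PARTITION_QUERY carries the SQL text to partition, and group 3 of RE_RUN_PARTITION carries the encoded partition id. A minimal standalone sketch, not part of the patch (statement text is illustrative):

    import re

    # Copied from the parser diff above.
    RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
    RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)

    match = RE_PARTITION_QUERY.match("PARTITION SELECT * FROM contacts")
    print(match.group(2))  # -> "SELECT * FROM contacts"

    match = RE_RUN_PARTITION.match("RUN PARTITION abc123==")
    print(match.group(3))  # -> "abc123=="

Because "RUN PARTITION ..." does not begin with PARTITION, the two patterns cannot both match the same statement, so the sequential checks in parse_stmt are unambiguous.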
re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE) RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE) +RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE) +RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE) def parse_stmt(query): @@ -48,6 +50,7 @@ def parse_stmt(query): :returns: ParsedStatement object. """ client_side_statement_type = None + client_side_statement_params = [] if RE_COMMIT.match(query): client_side_statement_type = ClientSideStatementType.COMMIT if RE_BEGIN.match(query): @@ -64,8 +67,19 @@ def parse_stmt(query): client_side_statement_type = ClientSideStatementType.RUN_BATCH if RE_ABORT_BATCH.match(query): client_side_statement_type = ClientSideStatementType.ABORT_BATCH + if RE_PARTITION_QUERY.match(query): + match = re.search(RE_PARTITION_QUERY, query) + client_side_statement_params.append(match.group(2)) + client_side_statement_type = ClientSideStatementType.PARTITION_QUERY + if RE_RUN_PARTITION.match(query): + match = re.search(RE_RUN_PARTITION, query) + client_side_statement_params.append(match.group(3)) + client_side_statement_type = ClientSideStatementType.RUN_PARTITION if client_side_statement_type is not None: return ParsedStatement( - StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type + StatementType.CLIENT_SIDE, + Statement(query), + client_side_statement_type, + client_side_statement_params, ) return None diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index e635563587..47680fd550 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -19,8 +19,15 @@ from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_dbapi import partition_helper from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor -from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement +from google.cloud.spanner_dbapi.parse_utils import _get_statement_type +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + Statement, + StatementType, +) +from google.cloud.spanner_dbapi.partition_helper import PartitionId from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot @@ -585,6 +592,54 @@ def abort_batch(self): self._batch_dml_executor = None self._batch_mode = BatchMode.NONE + @check_not_closed + def partition_query( + self, + parsed_statement: ParsedStatement, + query_options=None, + ): + statement = parsed_statement.statement + partitioned_query = parsed_statement.client_side_statement_params[0] + if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY: + raise ProgrammingError( + "Only queries can be partitioned. 
diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py
index 76ac951e0c..008f21bf93 100644
--- a/google/cloud/spanner_dbapi/parse_utils.py
+++ b/google/cloud/spanner_dbapi/parse_utils.py
@@ -232,19 +232,23 @@ def classify_statement(query, args=None):
         get_param_types(args or None),
         ResultsChecksum(),
     )
-    if RE_DDL.match(query):
-        return ParsedStatement(StatementType.DDL, statement)
+    statement_type = _get_statement_type(statement)
+    return ParsedStatement(statement_type, statement)

-    if RE_IS_INSERT.match(query):
-        return ParsedStatement(StatementType.INSERT, statement)

+def _get_statement_type(statement):
+    query = statement.sql
+    if RE_DDL.match(query):
+        return StatementType.DDL
+    if RE_IS_INSERT.match(query):
+        return StatementType.INSERT
     if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
         # As of 13-March-2020, Cloud Spanner only supports WITH for DQL
         # statements and doesn't yet support WITH for DML statements.
-        return ParsedStatement(StatementType.QUERY, statement)
+        return StatementType.QUERY

     statement.sql = ensure_where_clause(query)
-    return ParsedStatement(StatementType.UPDATE, statement)
+    return StatementType.UPDATE


 def sql_pyformat_args_to_spanner(sql, params):
diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py
index 4f633c7b10..798f5126c3 100644
--- a/google/cloud/spanner_dbapi/parsed_statement.py
+++ b/google/cloud/spanner_dbapi/parsed_statement.py
@@ -1,4 +1,4 @@
-# Copyright 20203 Google LLC All rights reserved.
+# Copyright 2023 Google LLC All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
 from enum import Enum
-from typing import Any
+from typing import Any, List

 from google.cloud.spanner_dbapi.checksum import ResultsChecksum

@@ -35,6 +35,8 @@ class ClientSideStatementType(Enum):
     START_BATCH_DML = 6
     RUN_BATCH = 7
     ABORT_BATCH = 8
+    PARTITION_QUERY = 9
+    RUN_PARTITION = 10


 @dataclass
@@ -53,3 +55,4 @@ class ParsedStatement:
     statement_type: StatementType
     statement: Statement
     client_side_statement_type: ClientSideStatementType = None
+    client_side_statement_params: List[Any] = None
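With this refactor, classify_statement delegates the DDL/INSERT/QUERY/UPDATE decision to the new _get_statement_type helper, and the new client_side_statement_params field carries the arguments extracted by the parser. Roughly, per the unit tests near the end of this change (query text is illustrative):

    from google.cloud.spanner_dbapi.parse_utils import classify_statement

    parsed = classify_statement("PARTITION SELECT * FROM contacts")
    # parsed.statement_type               -> StatementType.CLIENT_SIDE
    # parsed.client_side_statement_type   -> ClientSideStatementType.PARTITION_QUERY
    # parsed.client_side_statement_params -> ["SELECT * FROM contacts"]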
diff --git a/google/cloud/spanner_dbapi/partition_helper.py b/google/cloud/spanner_dbapi/partition_helper.py
new file mode 100644
index 0000000000..94b396c801
--- /dev/null
+++ b/google/cloud/spanner_dbapi/partition_helper.py
@@ -0,0 +1,46 @@
+# Copyright 2023 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any
+
+import gzip
+import pickle
+import base64
+
+
+def decode_from_string(encoded_partition_id):
+    gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
+    partition_id_bytes = gzip.decompress(gzip_bytes)
+    return pickle.loads(partition_id_bytes)
+
+
+def encode_to_string(batch_transaction_id, partition_result):
+    partition_id = PartitionId(batch_transaction_id, partition_result)
+    partition_id_bytes = pickle.dumps(partition_id)
+    gzip_bytes = gzip.compress(partition_id_bytes)
+    return str(base64.b64encode(gzip_bytes), "utf-8")
+
+
+@dataclass
+class BatchTransactionId:
+    transaction_id: str
+    session_id: str
+    read_timestamp: Any
+
+
+@dataclass
+class PartitionId:
+    batch_transaction_id: BatchTransactionId
+    partition_result: Any
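The helper makes a partition id an opaque, printable token: the PartitionId dataclass is pickled, gzip-compressed, and base64-encoded, so it survives a round trip through a SQL statement string. A quick round-trip sketch, not part of the patch (field values are placeholders):

    import datetime

    from google.cloud.spanner_dbapi.partition_helper import (
        BatchTransactionId,
        decode_from_string,
        encode_to_string,
    )

    txn_id = BatchTransactionId(
        transaction_id="txn-123",      # placeholder
        session_id="session-456",      # placeholder
        read_timestamp=datetime.datetime(2023, 11, 1),
    )
    encoded = encode_to_string(txn_id, "partition-token")  # placeholder result
    decoded = decode_from_string(encoded)
    assert decoded.batch_transaction_id == txn_id
    assert decoded.partition_result == "partition-token"

Since pickle is used, the decoding side must trust the encoded string; that is reasonable here because the ids are produced and consumed by the same application.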
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index e5f00c8ebd..c8c3b92edc 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -16,6 +16,7 @@

 import copy
 import functools
+
 import grpc
 import logging
 import re
@@ -39,6 +40,7 @@
 from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
 from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
 from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
+from google.cloud.spanner_dbapi.partition_helper import BatchTransactionId
 from google.cloud.spanner_v1 import ExecuteSqlRequest
 from google.cloud.spanner_v1 import TransactionSelector
 from google.cloud.spanner_v1 import TransactionOptions
@@ -747,7 +749,13 @@ def mutation_groups(self):
         """
         return MutationGroupsCheckout(self)

-    def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
+    def batch_snapshot(
+        self,
+        read_timestamp=None,
+        exact_staleness=None,
+        session_id=None,
+        transaction_id=None,
+    ):
         """Return an object which wraps a batch read / query.

         :type read_timestamp: :class:`datetime.datetime`
@@ -757,11 +765,21 @@ def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
         :param exact_staleness: Execute all reads at a timestamp that is
             ``exact_staleness`` old.

+        :type session_id: str
+        :param session_id: id of the session used in transaction
+
+        :type transaction_id: str
+        :param transaction_id: id of the transaction
+
         :rtype: :class:`~google.cloud.spanner_v1.database.BatchSnapshot`
         :returns: new wrapper
         """
         return BatchSnapshot(
-            self, read_timestamp=read_timestamp, exact_staleness=exact_staleness
+            self,
+            read_timestamp=read_timestamp,
+            exact_staleness=exact_staleness,
+            session_id=session_id,
+            transaction_id=transaction_id,
         )

     def run_in_transaction(self, func, *args, **kw):
@@ -1139,10 +1157,19 @@ class BatchSnapshot(object):
         ``exact_staleness`` old.
     """

-    def __init__(self, database, read_timestamp=None, exact_staleness=None):
+    def __init__(
+        self,
+        database,
+        read_timestamp=None,
+        exact_staleness=None,
+        session_id=None,
+        transaction_id=None,
+    ):
         self._database = database
+        self._session_id = session_id
         self._session = None
         self._snapshot = None
+        self._transaction_id = transaction_id
         self._read_timestamp = read_timestamp
         self._exact_staleness = exact_staleness

@@ -1190,7 +1217,10 @@ def _get_session(self):
         """
         if self._session is None:
             session = self._session = self._database.session()
-            session.create()
+            if self._session_id is None:
+                session.create()
+            else:
+                session._session_id = self._session_id
         return self._session

     def _get_snapshot(self):
@@ -1200,10 +1230,22 @@ def _get_snapshot(self):
                 read_timestamp=self._read_timestamp,
                 exact_staleness=self._exact_staleness,
                 multi_use=True,
+                transaction_id=self._transaction_id,
             )
-            self._snapshot.begin()
+            if self._transaction_id is None:
+                self._snapshot.begin()
         return self._snapshot

+    def get_batch_transaction_id(self):
+        snapshot = self._snapshot
+        if snapshot is None:
+            raise ValueError("Read-only transaction not begun")
+        return BatchTransactionId(
+            snapshot._transaction_id,
+            snapshot._session.session_id,
+            snapshot._read_timestamp,
+        )
+
     def read(self, *args, **kw):
         """Convenience method: perform read operation via snapshot.
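The new session_id/transaction_id parameters let a BatchSnapshot re-attach to a read-only transaction begun elsewhere instead of creating its own, which is exactly what Connection.run_partition relies on. A sketch, assuming an existing `database` handle (table name is illustrative):

    # Producer side: create partitions and capture the transaction identity.
    snapshot = database.batch_snapshot()
    batches = list(snapshot.generate_query_batches("SELECT * FROM contacts"))
    txn = snapshot.get_batch_transaction_id()

    # Consumer side: rebuild a snapshot bound to the same session and
    # transaction; _get_snapshot() skips begin() because transaction_id is set.
    attached = database.batch_snapshot(
        read_timestamp=txn.read_timestamp,
        session_id=txn.session_id,
        transaction_id=txn.transaction_id,
    )
    rows = list(attached.process(batches[0]))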
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index 37bed11d7e..491ff37d4a 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -738,6 +738,7 @@ def __init__(
         max_staleness=None,
         exact_staleness=None,
         multi_use=False,
+        transaction_id=None,
     ):
         super(Snapshot, self).__init__(session)
         opts = [read_timestamp, min_read_timestamp, max_staleness, exact_staleness]
@@ -760,6 +761,7 @@ def __init__(
         self._max_staleness = max_staleness
         self._exact_staleness = exact_staleness
         self._multi_use = multi_use
+        self._transaction_id = transaction_id

     def _make_txn_selector(self):
         """Helper for :meth:`read`."""
diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py
index fdea0b0d17..18bde6c94d 100644
--- a/tests/system/test_dbapi.py
+++ b/tests/system/test_dbapi.py
@@ -536,6 +536,74 @@ def test_batch_dml_invalid_statements(self):
         with pytest.raises(OperationalError):
             self._cursor.execute("run batch")

+    def test_partitioned_query(self):
+        """Test partition query works in read-only mode."""
+        self._cursor.execute("start batch dml")
+        for i in range(1, 11):
+            self._insert_row(i)
+        self._cursor.execute("run batch")
+        self._conn.commit()
+
+        self._conn.read_only = True
+        self._cursor.execute("PARTITION SELECT * FROM contacts")
+        partition_id_rows = self._cursor.fetchall()
+        assert len(partition_id_rows) > 0
+
+        rows = []
+        for partition_id_row in partition_id_rows:
+            self._cursor.execute("RUN PARTITION " + partition_id_row[0])
+            rows = rows + self._cursor.fetchall()
+        assert len(rows) == 10
+        self._conn.commit()
+
+    def test_partitioned_query_in_rw_transaction(self):
+        """Test partition query throws exception when connection is not in
+        read-only mode and neither in auto-commit mode."""
+
+        with pytest.raises(ProgrammingError):
+            self._cursor.execute("PARTITION SELECT * FROM contacts")
+
+    def test_partitioned_query_with_dml_query(self):
+        """Test partition query throws exception when sql query is a DML query."""
+
+        self._conn.read_only = True
+        with pytest.raises(ProgrammingError):
+            self._cursor.execute(
+                """
+                PARTITION INSERT INTO contacts (contact_id, first_name, last_name, email)
+                VALUES (1111, 'first-name', 'last-name', 'test.email@domen.ru')
+                """
+            )
+
+    def test_partitioned_query_in_autocommit_mode(self):
+        """Test partition query works when connection is not in read-only mode
+        but is in auto-commit mode."""
+        self._cursor.execute("start batch dml")
+        for i in range(1, 11):
+            self._insert_row(i)
+        self._cursor.execute("run batch")
+        self._conn.commit()
+
+        self._conn.autocommit = True
+        self._cursor.execute("PARTITION SELECT * FROM contacts")
+        partition_id_rows = self._cursor.fetchall()
+        assert len(partition_id_rows) > 0
+
+        rows = []
+        for partition_id_row in partition_id_rows:
+            self._cursor.execute("RUN PARTITION " + partition_id_row[0])
+            rows = rows + self._cursor.fetchall()
+        assert len(rows) == 10
+
+    def test_partitioned_query_with_client_transaction_started(self):
+        """Test partition query throws exception when connection is not in
+        read-only mode and transaction started using client side statement."""
+
+        self._conn.autocommit = True
+        self._cursor.execute("begin transaction")
+        with pytest.raises(ProgrammingError):
+            self._cursor.execute("PARTITION SELECT * FROM contacts")
+
     def _insert_row(self, i):
         self._cursor.execute(
             f"""
diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py
index 7f179d6d31..de7b9a6dce 100644
--- a/tests/unit/spanner_dbapi/test_parse_utils.py
+++ b/tests/unit/spanner_dbapi/test_parse_utils.py
@@ -15,9 +15,15 @@
 import sys
 import unittest

-from google.cloud.spanner_dbapi.parsed_statement import StatementType
+from google.cloud.spanner_dbapi.parsed_statement import (
+    StatementType,
+    ParsedStatement,
+    Statement,
+    ClientSideStatementType,
+)
 from google.cloud.spanner_v1 import param_types
 from google.cloud.spanner_v1 import JsonObject
+from google.cloud.spanner_dbapi.parse_utils import classify_statement


 class TestParseUtils(unittest.TestCase):
@@ -25,8 +31,6 @@ class TestParseUtils(unittest.TestCase):
     skip_message = "Subtests are not supported in Python 2"

     def test_classify_stmt(self):
-        from google.cloud.spanner_dbapi.parse_utils import classify_statement
-
         cases = (
             ("SELECT 1", StatementType.QUERY),
             ("SELECT s.SongName FROM Songs AS s", StatementType.QUERY),
@@ -71,6 +75,32 @@ def test_classify_stmt(self):
         for query, want_class in cases:
             self.assertEqual(classify_statement(query).statement_type, want_class)

+    def test_partition_query_classify_stmt(self):
+        parsed_statement = classify_statement(
+            " PARTITION SELECT s.SongName FROM Songs AS s "
+        )
+        self.assertEqual(
+            parsed_statement,
+            ParsedStatement(
+                StatementType.CLIENT_SIDE,
+                Statement("PARTITION SELECT s.SongName FROM Songs AS s"),
+                ClientSideStatementType.PARTITION_QUERY,
+                ["SELECT s.SongName FROM Songs AS s"],
+            ),
+        )
+
+    def test_run_partition_classify_stmt(self):
+        parsed_statement = classify_statement(" RUN PARTITION bj2bjb2j2bj2ebbh ")
+        self.assertEqual(
+            parsed_statement,
+            ParsedStatement(
+                StatementType.CLIENT_SIDE,
+                Statement("RUN PARTITION bj2bjb2j2bj2ebbh"),
+                ClientSideStatementType.RUN_PARTITION,
+                ["bj2bjb2j2bj2ebbh"],
+            ),
+        )
+
     @unittest.skipIf(skip_condition, skip_message)
     def test_sql_pyformat_args_to_spanner(self):
         from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py
index 5f563773bc..88e7bf8f66 100644
--- a/tests/unit/test_database.py
+++ b/tests/unit/test_database.py
@@ -2138,7 +2138,10 @@ def test__get_snapshot_new_wo_staleness(self):
         snapshot = session.snapshot.return_value = self._make_snapshot()
         self.assertIs(batch_txn._get_snapshot(), snapshot)
         session.snapshot.assert_called_once_with(
-            read_timestamp=None, exact_staleness=None, multi_use=True
+            read_timestamp=None,
+            exact_staleness=None,
+            multi_use=True,
+            transaction_id=None,
         )
         snapshot.begin.assert_called_once_with()

@@ -2150,7 +2153,10 @@ def test__get_snapshot_w_read_timestamp(self):
         snapshot = session.snapshot.return_value = self._make_snapshot()
         self.assertIs(batch_txn._get_snapshot(), snapshot)
         session.snapshot.assert_called_once_with(
-            read_timestamp=timestamp, exact_staleness=None, multi_use=True
+            read_timestamp=timestamp,
+            exact_staleness=None,
+            multi_use=True,
+            transaction_id=None,
         )
         snapshot.begin.assert_called_once_with()

@@ -2162,7 +2168,10 @@ def test__get_snapshot_w_exact_staleness(self):
         snapshot = session.snapshot.return_value = self._make_snapshot()
         self.assertIs(batch_txn._get_snapshot(), snapshot)
         session.snapshot.assert_called_once_with(
-            read_timestamp=None, exact_staleness=duration, multi_use=True
+            read_timestamp=None,
+            exact_staleness=duration,
+            multi_use=True,
+            transaction_id=None,
         )
         snapshot.begin.assert_called_once_with()
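Since the encoded partition ids are plain strings, the intended workflow is to hand them out to separate workers, each of which executes its share with RUN PARTITION. A hypothetical fan-out sketch built on this feature — names, table, and pool size are illustrative, not part of this change:

    from multiprocessing import Pool

    from google.cloud.spanner_dbapi.connection import connect


    def run_partition(encoded_id):
        # Each worker opens its own connection; RUN PARTITION re-attaches to
        # the read-only transaction encoded in the partition id.
        conn = connect("my-instance", "my-database")  # illustrative names
        conn.read_only = True
        cursor = conn.cursor()
        cursor.execute("RUN PARTITION " + encoded_id)
        return cursor.fetchall()


    if __name__ == "__main__":
        conn = connect("my-instance", "my-database")
        conn.read_only = True
        cursor = conn.cursor()
        cursor.execute("PARTITION SELECT * FROM contacts")
        ids = [row[0] for row in cursor.fetchall()]
        with Pool(processes=4) as pool:
            results = pool.map(run_partition, ids)  # list of row lists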