diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index dfbf33c1e8..b1ed2873ae 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -105,6 +105,8 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): ) if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY: return connection.run_partitioned_query(parsed_statement) + if statement_type == ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE: + return connection._set_autocommit_dml_mode(parsed_statement) def _get_streamed_result_set(column_name, type_code, column_values): diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index 63188a032a..002779adb4 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -38,6 +38,9 @@ RE_RUN_PARTITIONED_QUERY = re.compile( r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE ) +RE_SET_AUTOCOMMIT_DML_MODE = re.compile( + r"^\s*(SET)\s+(AUTOCOMMIT_DML_MODE)\s+(=)\s+(.+)", re.IGNORECASE +) def parse_stmt(query): @@ -82,6 +85,10 @@ def parse_stmt(query): match = re.search(RE_RUN_PARTITION, query) client_side_statement_params.append(match.group(3)) client_side_statement_type = ClientSideStatementType.RUN_PARTITION + elif RE_SET_AUTOCOMMIT_DML_MODE.match(query): + match = re.search(RE_SET_AUTOCOMMIT_DML_MODE, query) + client_side_statement_params.append(match.group(4)) + client_side_statement_type = ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE if client_side_statement_type is not None: return ParsedStatement( StatementType.CLIENT_SIDE, diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 70b0f2cfbc..3dec2bd028 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -23,6 +23,7 @@ from google.cloud.spanner_dbapi.parse_utils import _get_statement_type from google.cloud.spanner_dbapi.parsed_statement import ( StatementType, + AutocommitDmlMode, ) from google.cloud.spanner_dbapi.partition_helper import PartitionId from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement @@ -116,6 +117,7 @@ def __init__(self, instance, database=None, read_only=False): self._batch_mode = BatchMode.NONE self._batch_dml_executor: BatchDmlExecutor = None self._transaction_helper = TransactionRetryHelper(self) + self._autocommit_dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL @property def spanner_client(self): @@ -167,6 +169,23 @@ def database(self): """ return self._database + @property + def autocommit_dml_mode(self): + """Modes for executing DML statements in autocommit mode for this connection. + + The DML autocommit modes are: + 1) TRANSACTIONAL - DML statements are executed as single read-write transaction. + After successful execution, the DML statement is guaranteed to have been applied + exactly once to the database. + + 2) PARTITIONED_NON_ATOMIC - DML statements are executed as partitioned DML transactions. + If an error occurs during the execution of the DML statement, it is possible that the + statement has been applied to some but not all of the rows specified in the statement. + + :rtype: :class:`~google.cloud.spanner_dbapi.parsed_statement.AutocommitDmlMode` + """ + return self._autocommit_dml_mode + @property @deprecated( reason="This method is deprecated. Use _spanner_transaction_started field" @@ -577,6 +596,37 @@ def run_partitioned_query( partitioned_query, statement.params, statement.param_types ) + @check_not_closed + def _set_autocommit_dml_mode( + self, + parsed_statement: ParsedStatement, + ): + autocommit_dml_mode_str = parsed_statement.client_side_statement_params[0] + autocommit_dml_mode = AutocommitDmlMode[autocommit_dml_mode_str.upper()] + self.set_autocommit_dml_mode(autocommit_dml_mode) + + def set_autocommit_dml_mode( + self, + autocommit_dml_mode, + ): + """ + Sets the mode for executing DML statements in autocommit mode for this connection. + This mode is only used when the connection is in autocommit mode, and may only + be set while the transaction is in autocommit mode and not in a temporary transaction. + """ + + if self._client_transaction_started is True: + raise ProgrammingError( + "Cannot set autocommit DML mode while not in autocommit mode or while a transaction is active." + ) + if self.read_only is True: + raise ProgrammingError( + "Cannot set autocommit DML mode for a read-only connection." + ) + if self._batch_mode is not BatchMode.NONE: + raise ProgrammingError("Cannot set autocommit DML mode while in a batch.") + self._autocommit_dml_mode = autocommit_dml_mode + def _partitioned_query_validation(self, partitioned_query, statement): if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY: raise ProgrammingError( diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 2bd46ab643..bd2ad974f9 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -45,6 +45,7 @@ StatementType, Statement, ParsedStatement, + AutocommitDmlMode, ) from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType from google.cloud.spanner_dbapi.utils import PeekIterator @@ -272,6 +273,17 @@ def _execute(self, sql, args=None, call_from_execute_many=False): self._batch_DDLs(sql) if not self.connection._client_transaction_started: self.connection.run_prior_DDL_statements() + elif ( + self.connection.autocommit_dml_mode + is AutocommitDmlMode.PARTITIONED_NON_ATOMIC + ): + self._row_count = self.connection.database.execute_partitioned_dml( + sql, + params=args, + param_types=self._parsed_statement.statement.param_types, + request_options=self.connection.request_options, + ) + self._result_set = None else: self._execute_in_rw_transaction() diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 1bb0ed25f4..f89d6ea19e 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -36,6 +36,12 @@ class ClientSideStatementType(Enum): PARTITION_QUERY = 9 RUN_PARTITION = 10 RUN_PARTITIONED_QUERY = 11 + SET_AUTOCOMMIT_DML_MODE = 12 + + +class AutocommitDmlMode(Enum): + TRANSACTIONAL = 1 + PARTITIONED_NON_ATOMIC = 2 @dataclass diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 52a80d5714..67854eeeac 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -26,6 +26,7 @@ OperationalError, RetryAborted, ) +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode from google.cloud.spanner_v1 import JsonObject from google.cloud.spanner_v1 import gapic_version as package_version from google.api_core.datetime_helpers import DatetimeWithNanoseconds @@ -669,6 +670,27 @@ def test_run_partitioned_query(self): assert len(rows) == 10 self._conn.commit() + def test_partitioned_dml_query(self): + """Test partitioned_dml query works in autocommit mode.""" + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + self._conn.autocommit = True + self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC") + self._cursor.execute("DELETE FROM contacts WHERE contact_id > 3") + assert self._cursor.rowcount == 7 + + self._cursor.execute("set autocommit_dml_mode = TRANSACTIONAL") + assert self._conn.autocommit_dml_mode == AutocommitDmlMode.TRANSACTIONAL + + self._conn.autocommit = False + # Test changing autocommit_dml_mode is not allowed when connection is in autocommit mode + with pytest.raises(ProgrammingError): + self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC") + def _insert_row(self, i): self._cursor.execute( f""" diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index c62d226a29..d0fa521f8f 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -33,6 +33,8 @@ ParsedStatement, StatementType, Statement, + ClientSideStatementType, + AutocommitDmlMode, ) PROJECT = "test-project" @@ -433,6 +435,62 @@ def test_abort_dml_batch(self, mock_batch_dml_executor): self.assertEqual(self._under_test._batch_mode, BatchMode.NONE) self.assertEqual(self._under_test._batch_dml_executor, None) + def test_set_autocommit_dml_mode_with_autocommit_false(self): + self._under_test.autocommit = False + parsed_statement = ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("sql"), + ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE, + ["PARTITIONED_NON_ATOMIC"], + ) + + with self.assertRaises(ProgrammingError): + self._under_test._set_autocommit_dml_mode(parsed_statement) + + def test_set_autocommit_dml_mode_with_readonly(self): + self._under_test.autocommit = True + self._under_test.read_only = True + parsed_statement = ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("sql"), + ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE, + ["PARTITIONED_NON_ATOMIC"], + ) + + with self.assertRaises(ProgrammingError): + self._under_test._set_autocommit_dml_mode(parsed_statement) + + def test_set_autocommit_dml_mode_with_batch_mode(self): + self._under_test.autocommit = True + parsed_statement = ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("sql"), + ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE, + ["PARTITIONED_NON_ATOMIC"], + ) + + self._under_test._set_autocommit_dml_mode(parsed_statement) + + assert ( + self._under_test.autocommit_dml_mode + == AutocommitDmlMode.PARTITIONED_NON_ATOMIC + ) + + def test_set_autocommit_dml_mode(self): + self._under_test.autocommit = True + parsed_statement = ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("sql"), + ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE, + ["PARTITIONED_NON_ATOMIC"], + ) + + self._under_test._set_autocommit_dml_mode(parsed_statement) + assert ( + self._under_test.autocommit_dml_mode + == AutocommitDmlMode.PARTITIONED_NON_ATOMIC + ) + @mock.patch("google.cloud.spanner_v1.database.Database", autospec=True) def test_run_prior_DDL_statements(self, mock_database): from google.cloud.spanner_dbapi import Connection, InterfaceError diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 239fc9d6b3..3a325014fa 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -115,6 +115,20 @@ def test_run_partitioned_query_classify_stmt(self): ), ) + def test_set_autocommit_dml_mode_stmt(self): + parsed_statement = classify_statement( + " set autocommit_dml_mode = PARTITIONED_NON_ATOMIC " + ) + self.assertEqual( + parsed_statement, + ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC"), + ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE, + ["PARTITIONED_NON_ATOMIC"], + ), + ) + @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self): from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner