Skip to content

Commit

Permalink
feat: Implementation for Begin and Rollback clientside statements (#1041
Browse files Browse the repository at this point in the history
)

* fix: Refactoring tests to use fixtures properly

* Not using autouse fixtures for few tests where not needed

* feat: Implementation for Begin and Rollback clientside statements

* Incorporating comments

* Formatting

* Comments incorporated

* Fixing tests

* Small fix

* Test fix as emulator was going OOM
  • Loading branch information
ankiaga committed Dec 4, 2023
1 parent aa36b07 commit 15623cd
Show file tree
Hide file tree
Showing 8 changed files with 824 additions and 733 deletions.
13 changes: 12 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_executor.py
Expand Up @@ -11,19 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)


def execute(connection, parsed_statement: ParsedStatement):
def execute(connection: "Connection", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.
It is an internal method that can make backwards-incompatible changes.
:type connection: Connection
:param connection: Connection object of the dbApi
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
10 changes: 10 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Expand Up @@ -20,7 +20,9 @@
ClientSideStatementType,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -39,4 +41,12 @@ def parse_stmt(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
if RE_ROLLBACK.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
)
return None
101 changes: 74 additions & 27 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -34,7 +34,9 @@
from google.rpc.code_pb2 import ABORTED


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not started"
)
MAX_INTERNAL_RETRIES = 50


Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
self._read_only = read_only
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False

@property
def autocommit(self):
Expand All @@ -122,7 +125,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit and self.inside_transaction:
if value and not self._autocommit and self._spanner_transaction_started:
self.commit()

self._autocommit = value
Expand All @@ -137,17 +140,35 @@ def database(self):
return self._database

@property
def inside_transaction(self):
"""Flag: transaction is started.
def _spanner_transaction_started(self):
"""Flag: whether transaction started at Spanner. This means that we had
made atleast one call to Spanner. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner
Returns:
bool: True if transaction begun, False otherwise.
bool: True if Spanner transaction started, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)
) or (self._snapshot is not None)

@property
def inside_transaction(self):
"""Deprecated property which won't be supported in future versions.
Please use spanner_transaction_started property instead."""
return self._spanner_transaction_started

@property
def _client_transaction_started(self):
"""Flag: whether transaction started at client side.
Returns:
bool: True if transaction started, False otherwise.
"""
return (not self._autocommit) or self._transaction_begin_marked

@property
def instance(self):
Expand Down Expand Up @@ -175,7 +196,7 @@ def read_only(self, value):
Args:
value (bool): True for ReadOnly mode, False for ReadWrite.
"""
if self.inside_transaction:
if self._spanner_transaction_started:
raise ValueError(
"Connection read/write mode can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -213,7 +234,7 @@ def staleness(self, value):
Args:
value (dict): Staleness type and value.
"""
if self.inside_transaction:
if self._spanner_transaction_started:
raise ValueError(
"`staleness` option can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -331,15 +352,16 @@ def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.
this connection yet. Return the started one otherwise.
The method is non operational in autocommit mode.
This method is a no-op if the connection is in autocommit mode and no
explicit transaction has been started
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if not self.inside_transaction:
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -354,7 +376,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and not self.autocommit:
if self.read_only and self._client_transaction_started:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -369,55 +391,80 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self.inside_transaction:
if self._spanner_transaction_started and not self.read_only:
self._transaction.rollback()

if self._own_pool and self.database:
self.database._pool.clear()

self.is_closed = True

@check_not_closed
def begin(self):
"""
Marks the transaction as started.
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already started")
if self._spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
self._transaction_begin_marked = True

def commit(self):
"""Commits any pending transaction to the database.
This method is non-operational in autocommit mode.
This is a no-op if there is no active client transaction.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

self.run_prior_DDL_statements()
if self.inside_transaction:
if self._spanner_transaction_started:
try:
if not self.read_only:
if self.read_only:
self._snapshot = None
else:
self._transaction.commit()

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()

def rollback(self):
"""Rolls back any pending transaction.
This is a no-op if there is no active transaction or if the connection
is in autocommit mode.
This is a no-op if there is no active client transaction.
"""
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
if not self.read_only:
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

if self._spanner_transaction_started:
if self.read_only:
self._snapshot = None
else:
self._transaction.rollback()

self._release_session()
self._statements = []
self._transaction_begin_marked = False

@check_not_closed
def cursor(self):
Expand Down
23 changes: 16 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
if self.connection._client_transaction_started:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if self.connection.autocommit:
if not self.connection._client_transaction_started:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -396,7 +396,10 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
return res
except StopIteration:
Expand All @@ -414,7 +417,10 @@ def fetchall(self):
res = []
try:
for row in self:
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(row)
res.append(row)
except Aborted:
Expand Down Expand Up @@ -443,7 +449,10 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
Expand Down Expand Up @@ -473,7 +482,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
if self.connection.read_only and self.connection._client_transaction_started:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Expand Up @@ -27,6 +27,7 @@ class StatementType(Enum):
class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2
ROLLBACK = 3


@dataclass
Expand Down

0 comments on commit 15623cd

Please sign in to comment.