From 1be0842f3a5083c4bc576af5a476fde614659473 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 18 Mar 2026 08:16:06 +0000 Subject: [PATCH 01/12] feat: initial scaffolding for the `google-cloud-spanner-dbapi-driver` package, including core files, tests, documentation, and build configurations. --- packages/google-cloud-spanner-dbapi-driver/README.rst | 2 +- packages/google-cloud-spanner-dbapi-driver/docs/README.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/README.rst b/packages/google-cloud-spanner-dbapi-driver/README.rst index 859400cc6da4..29d7be7da11e 100644 --- a/packages/google-cloud-spanner-dbapi-driver/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-spanner-dbapi-driver.svg :target: https://pypi.org/project/google-cloud-spanner-dbapi-driver/ diff --git a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst index 859400cc6da4..29d7be7da11e 100644 --- a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. 
|pypi| image:: https://img.shields.io/pypi/v/google-cloud-spanner-dbapi-driver.svg :target: https://pypi.org/project/google-cloud-spanner-dbapi-driver/ From 141901fa1e977f94c154eb17d04af1ffdeaf52a1 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Mon, 23 Mar 2026 17:05:39 +0530 Subject: [PATCH 02/12] Update packages/google-cloud-spanner-dbapi-driver/README.rst Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- packages/google-cloud-spanner-dbapi-driver/README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/README.rst b/packages/google-cloud-spanner-dbapi-driver/README.rst index 29d7be7da11e..859400cc6da4 100644 --- a/packages/google-cloud-spanner-dbapi-driver/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg +.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. 
|pypi| image:: https://img.shields.io/pypi/v/google-cloud-spanner-dbapi-driver.svg :target: https://pypi.org/project/google-cloud-spanner-dbapi-driver/ From 31a8aad305358cf8ba870753352e3610d6285dc0 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Mon, 23 Mar 2026 17:05:52 +0530 Subject: [PATCH 03/12] Update packages/google-cloud-spanner-dbapi-driver/docs/README.rst Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- packages/google-cloud-spanner-dbapi-driver/docs/README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst index 29d7be7da11e..859400cc6da4 100644 --- a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg +.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-spanner-dbapi-driver.svg :target: https://pypi.org/project/google-cloud-spanner-dbapi-driver/ From 09f34d91c608786fcfc6aca3f59f9b5caa836a09 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Tue, 24 Mar 2026 07:26:48 +0000 Subject: [PATCH 04/12] feat: Implement initial Google Cloud Spanner DB-API 2.0 driver with core components and comprehensive unit and system tests. 
--- .../google/cloud/spanner_driver/__init__.py | 59 ++- .../google/cloud/spanner_driver/connection.py | 138 +++++ .../google/cloud/spanner_driver/cursor.py | 475 ++++++++++++++++++ .../google/cloud/spanner_driver/errors.py | 241 +++++++++ .../google/cloud/spanner_driver/types.py | 170 +++++++ .../noxfile.py | 1 + .../tests/system/_helper.py | 58 +++ .../tests/system/test_connection.py | 44 ++ .../tests/system/test_cursor.py | 144 ++++++ .../tests/system/test_errors.py | 74 +++ .../tests/system/test_executemany.py | 64 +++ .../tests/system/test_transaction.py | 116 +++++ .../tests/unit/conftest.py | 221 ++++++++ .../tests/unit/test_connection.py | 111 ++++ .../tests/unit/test_cursor.py | 349 +++++++++++++ .../tests/unit/test_errors.py | 57 +++ .../tests/unit/test_types.py | 57 +++ 17 files changed, 2376 insertions(+), 3 deletions(-) create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/_helper.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_connection.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_executemany.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/system/test_transaction.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/conftest.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py create mode 100644 
packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_errors.py create mode 100644 packages/google-cloud-spanner-dbapi-driver/tests/unit/test_types.py diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py index b75f2a4d398f..d898b418c6f5 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py @@ -14,17 +14,70 @@ """Spanner Python Driver.""" import logging +from typing import Final -from . import version as package_version +from .connection import Connection, connect +from .cursor import Cursor from .dbapi import apilevel, paramstyle, threadsafety +from .errors import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) +from .types import ( + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, +) -__version__ = package_version.__version__ +__version__: Final[str] = "0.0.1" logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) __all__: list[str] = [ "apilevel", - "paramstyle", "threadsafety", + "paramstyle", + "Connection", + "connect", + "Cursor", + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "Binary", + "STRING", + "BINARY", + "NUMBER", + "DATETIME", + "ROWID", + "InterfaceError", + "ProgrammingError", + "OperationalError", + "DatabaseError", + "DataError", + "NotSupportedError", + "IntegrityError", + "InternalError", + "Warning", + "Error", ] diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py 
b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py new file mode 100644 index 000000000000..12e4c3638d98 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py @@ -0,0 +1,138 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any + +from google.cloud.spannerlib.pool import Pool + +from . import errors +from .cursor import Cursor + +logger = logging.getLogger(__name__) + + +def check_not_closed(function): + """`Connection` class methods decorator. + + Raise an exception if the connection is closed. + + :raises: :class:`InterfaceError` if the connection is closed. + """ + + def wrapper(connection, *args, **kwargs): + if connection._closed: + raise errors.InterfaceError("Connection is closed") + + return function(connection, *args, **kwargs) + + return wrapper + + +class Connection: + """Connection to a Google Cloud Spanner database. + + This class provides a connection to the Spanner database and adheres to + PEP 249 (Python Database API Specification v2.0). 
+ """ + + def __init__(self, internal_connection: Any): + """ + Args: + internal_connection: An instance of + google.cloud.spannerlib.Connection + """ + self._internal_conn = internal_connection + self._closed = False + self._messages: list[Any] = [] + + @property + def messages(self) -> list[Any]: + """Return the list of messages sent to the client by the database.""" + return self._messages + + @check_not_closed + def cursor(self) -> Cursor: + """Return a new Cursor Object using the connection. + + Returns: + Cursor: A cursor object. + """ + return Cursor(self) + + @check_not_closed + def begin(self) -> None: + """Begin a new transaction.""" + logger.debug("Beginning transaction") + try: + self._internal_conn.begin_transaction() + except Exception as e: + raise errors.map_spanner_error(e) + + @check_not_closed + def commit(self) -> None: + """Commit any pending transaction to the database. + + This is a no-op if there is no active client transaction. + """ + logger.debug("Committing transaction") + try: + self._internal_conn.commit() + except Exception as e: + # raise errors.map_spanner_error(e) + logger.debug(f"Commit failed {e}") + + @check_not_closed + def rollback(self) -> None: + """Rollback any pending transaction to the database. + + This is a no-op if there is no active client transaction. + """ + logger.debug("Rolling back transaction") + try: + self._internal_conn.rollback() + except Exception as e: + # raise errors.map_spanner_error(e) + logger.debug(f"Rollback failed {e}") + + def close(self) -> None: + """Close the connection now. + + The connection will be unusable from this point forward; an Error (or + subclass) exception will be raised if any operation is attempted with + the connection. The same applies to all cursor objects trying to use + the connection. 
+ """ + if self._closed: + raise errors.InterfaceError("Connection is already closed") + + logger.debug("Closing connection") + self._internal_conn.close() + self._closed = True + + def __enter__(self) -> "Connection": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + +def connect(connection_string: str, **kwargs: Any) -> Connection: + logger.debug(f"Connecting to {connection_string}") + # Create the pool + pool = Pool.create_pool(connection_string) + + # Create the low-level connection + internal_conn = pool.create_connection() + + return Connection(internal_conn) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py new file mode 100644 index 000000000000..a81e95ef47e8 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py @@ -0,0 +1,475 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import datetime +import logging +import uuid +from enum import Enum +from typing import TYPE_CHECKING, Any + +from google.cloud.spanner_v1 import ( + ExecuteBatchDmlRequest, + ExecuteSqlRequest, + Type, + TypeCode, +) + +from . 
import errors +from .types import _type_code_to_dbapi_type + +if TYPE_CHECKING: + from .connection import Connection + +logger = logging.getLogger(__name__) + + +def check_not_closed(function): + """`Cursor` class methods decorator. + + Raise an exception if the cursor is closed. + + :raises: :class:`InterfaceError` if the cursor is closed. + """ + + def wrapper(cursor, *args, **kwargs): + if cursor._closed: + raise errors.InterfaceError("Cursor is closed") + + return function(cursor, *args, **kwargs) + + return wrapper + + +class FetchScope(Enum): + FETCH_ONE = 1 + FETCH_MANY = 2 + FETCH_ALL = 3 + + +class Cursor: + """Cursor object for the Google Cloud Spanner database. + + This class lets you use a cursor to interact with the database. + """ + + def __init__(self, connection: "Connection"): + self._connection = connection + self._rows: Any = None # Holds the google.cloud.spannerlib.rows.Rows object + self._closed = False + self.arraysize = 1 + self._rowcount = -1 + + @property + def description(self) -> tuple[tuple[Any, ...], ...] | None: + """ + This read-only attribute is a sequence of 7-item sequences. + + Each of these sequences contains information describing one result + column: + - name + - type_code + - display_size + - internal_size + - precision + - scale + - null_ok + + The first two items (name and type_code) are mandatory, the other + five are optional and are set to None if no meaningful values can be + provided. + + This attribute will be None for operations that do not return rows or + if the cursor has not had an operation invoked via the .execute*() + method yet. 
+ """ + logger.debug("Fetching description for cursor") + if not self._rows: + return None + + try: + metadata = self._rows.metadata() + if not metadata or not metadata.row_type: + return None + + desc = [] + for field in metadata.row_type.fields: + desc.append( + ( + field.name, + _type_code_to_dbapi_type(field.type.code), + None, # display_size + None, # internal_size + None, # precision + None, # scale + True, # null_ok + ) + ) + return tuple(desc) + except Exception: + return None + + @property + def rowcount(self) -> int: + """ + This read-only attribute specifies the number of rows that the last + .execute*() produced (for DQL statements like 'select') or affected + (for DML statements like 'update' or 'insert'). + + The attribute is -1 in case no .execute*() has been performed on the + cursor or the rowcount of the last operation cannot be determined by + the interface. + """ + return self._rowcount + + def _prepare_params( + self, parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None + ) -> (dict[str, Any] | None, dict[str, Type] | None): + """ + Prepares parameters for Spanner execution + + Args: + parameters: A dictionary (for named parameters/GoogleSQL) + or a list/tuple + (for positional parameters/PostgreSQL). + + Returns: + A tuple containing: + - converted_params: Dictionary of parameters with values + converted for Spanner (e.g. ints to strings). + - param_types: Dictionary mapping parameter names to + their Spanner Type. + """ + if not parameters: + return {}, {} + + converted_params = {} + param_types = {} + + # Normalize input to an iterable of (key, value) + if isinstance(parameters, (list, tuple)): + # PostgreSQL Dialect: Positional parameters $1, $2... are + # mapped to P1, P2... + iterator = ((f"P{i}", val) for i, val in enumerate(parameters, 1)) + elif isinstance(parameters, dict): + # GoogleSQL Dialect: Named parameters @name are mapped directly. 
+ iterator = parameters.items() + else: + # If strictly required, raise an error for unsupported types + return {}, {} + + for key, value in iterator: + if value is None: + converted_params[key] = None + continue + # Note: check bool before int, as bool is a subclass of int + if isinstance(value, bool): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.BOOL) + elif isinstance(value, int): + # Spanner expects INT64 as strings to preserve precision + # in JSON + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.INT64) + elif isinstance(value, float): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.FLOAT64) + elif isinstance(value, bytes): + converted_params[key] = value + param_types[key] = Type(code=TypeCode.BYTES) + elif isinstance(value, uuid.UUID): + # Convert UUID to string as requested + converted_params[key] = str(value) + # Use STRING type for UUIDs (unless specific UUID type is + # required/supported by your backend version) + param_types[key] = Type(code=TypeCode.STRING) + elif isinstance(value, datetime.datetime): + # Convert Datetime to string (RFC 3339 format is standard + # for str(datetime)) + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.TIMESTAMP) + elif isinstance(value, datetime.date): + converted_params[key] = str(value) + param_types[key] = Type(code=TypeCode.DATE) + else: + # Fallback for strings and other types + converted_params[key] = value + # For strings, we can explicitly set the type or let it default. + if isinstance(value, str): + param_types[key] = Type(code=TypeCode.STRING) + + return converted_params, param_types + + @check_not_closed + def execute( + self, + operation: str, + parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None, + ) -> None: + """Prepare and execute a database operation (query or command). + + Parameters may be provided as sequence or mapping and will be bound to + variables in the operation. 
Variables are specified in a + database-specific notation (see the module's paramstyle attribute for + details). + + Args: + operation (str): The SQL statement to execute. + parameters (dict | list | tuple, optional): parameters to bind. + """ + logger.debug(f"Executing operation: {operation}") + + request = ExecuteSqlRequest(sql=operation) + params, _ = self._prepare_params(parameters) + request.params = params + + try: + self._rows = self._connection._internal_conn.execute(request) + + if self.description: + self._rowcount = -1 + else: + update_count = self._rows.update_count() + if update_count != -1: + self._rowcount = update_count + self._rows.close() + self._rows = None + + except Exception as e: + raise errors.map_spanner_error(e) from e + + @check_not_closed + def executemany( + self, + operation: str, + seq_of_parameters: (list[dict[str, Any]] | list[list[Any]] | list[tuple[Any]]), + ) -> None: + """Prepare a database operation (query or command) and then execute it + against all parameter sequences or mappings found in the sequence + seq_of_parameters. + + Args: + operation (str): The SQL statement to execute. + seq_of_parameters (list): A list of parameter sequences/mappings. 
+ """ + logger.debug(f"Executing batch operation: {operation}") + + request = ExecuteBatchDmlRequest() + + for parameters in seq_of_parameters: + statement = ExecuteBatchDmlRequest.Statement(sql=operation) + params, _ = self._prepare_params(parameters) + statement.params = params + + request.statements.append(statement) + + try: + response = self._connection._internal_conn.execute_batch(request) + total_rowcount = 0 + for result_set in response.result_sets: + if result_set.stats.row_count_exact != -1: + total_rowcount += result_set.stats.row_count_exact + elif result_set.stats.row_count_lower_bound != -1: + total_rowcount += result_set.stats.row_count_lower_bound + self._rowcount = total_rowcount + + except Exception as e: + raise errors.map_spanner_error(e) from e + + def _convert_value(self, value: Any, field_type: Any) -> Any: + kind = value.WhichOneof("kind") + if kind == "null_value": + return None + if kind == "bool_value": + return value.bool_value + if kind == "number_value": + return value.number_value + if kind == "string_value": + code = field_type.code + val = value.string_value + if code == TypeCode.INT64: + return int(val) + if code == TypeCode.BYTES or code == TypeCode.PROTO: + return base64.b64decode(val) + return val + if kind == "list_value": + return [ + self._convert_value(v, field_type.array_element_type) + for v in value.list_value.values + ] + # Fallback for complex types (structs) not fully mapped yet + return value + + def _convert_row(self, row: Any) -> tuple[Any, ...]: + metadata = self._rows.metadata() + fields = metadata.row_type.fields + converted = [] + for i, value in enumerate(row.values): + converted.append(self._convert_value(value, fields[i].type)) + return tuple(converted) + + def _fetch( + self, scope: FetchScope, size: int | None = None + ) -> list[tuple[Any, ...]]: + if not self._rows: + raise errors.ProgrammingError("No result set available") + try: + rows = [] + if scope == FetchScope.FETCH_ONE: + try: + row = 
self._rows.next() + if row is not None: + rows.append(self._convert_row(row)) + except StopIteration: + pass + elif scope == FetchScope.FETCH_MANY: + # size is guaranteed to be int if scope is FETCH_MANY and + # called from fetchmany but might be None if internal logic + # changes, strict check would satisfy type checker + limit = size if size is not None else self.arraysize + for _ in range(limit): + try: + row = self._rows.next() + if row is None: + break + rows.append(self._convert_row(row)) + except StopIteration: + break + elif scope == FetchScope.FETCH_ALL: + while True: + try: + row = self._rows.next() + if row is None: + break + rows.append(self._convert_row(row)) + except StopIteration: + break + except Exception as e: + raise errors.map_spanner_error(e) from e + + return rows + + @check_not_closed + def fetchone(self) -> tuple[Any, ...] | None: + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available. + + Returns: + tuple | None: A row of data or None. + """ + logger.debug("Fetching one row") + rows = self._fetch(FetchScope.FETCH_ONE) + if not rows: + return None + return rows[0] + + @check_not_closed + def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]: + """Fetch the next set of rows of a query result, returning a sequence + of sequences (e.g. a list of tuples). An empty sequence is returned + when no more rows are available. + + The number of rows to fetch per call is specified by the parameter. If + it is not given, the cursor's arraysize determines the number of rows + to be fetched. + + Args: + size (int, optional): The number of rows to fetch. + + Returns: + list[tuple]: A list of rows. 
+ """ + logger.debug("Fetching many rows") + if size is None: + size = self.arraysize + return self._fetch(FetchScope.FETCH_MANY, size) + + @check_not_closed + def fetchall(self) -> list[tuple[Any, ...]]: + """Fetch all (remaining) rows of a query result, returning them as a + sequence of sequences (e.g. a list of tuples). + + Returns: + list[tuple]: A list of rows. + """ + logger.debug("Fetching all rows") + return self._fetch(FetchScope.FETCH_ALL) + + def close(self) -> None: + """Close the cursor now. + + The cursor will be unusable from this point forward; an Error (or + subclass) exception will be raised if any operation is attempted with + the cursor. + """ + logger.debug("Closing cursor") + self._closed = True + if self._rows: + self._rows.close() + + @check_not_closed + def nextset(self) -> bool | None: + """Skip to the next available set of results.""" + logger.debug("Fetching next set of results") + if not self._rows: + return None + + try: + next_metadata = self._rows.next_result_set() + if next_metadata: + return True + return None + except Exception: + return None + + def __enter__(self) -> "Cursor": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def __iter__(self) -> "Cursor": + return self + + def __next__(self) -> tuple[Any, ...]: + row = self.fetchone() + if row is None: + raise StopIteration + return row + + @check_not_closed + def setinputsizes(self, sizes: list[Any]) -> None: + """Predefine memory areas for parameters. + This operation is a no-op implementation. + """ + logger.debug("NO-OP: Setting input sizes") + pass + + @check_not_closed + def setoutputsize(self, size: int, column: int | None = None) -> None: + """Set a column buffer size. + This operation is a no-op implementation. 
+ """ + logger.debug("NO-OP: Setting output size") + pass + + @check_not_closed + def callproc( + self, procname: str, parameters: list[Any] | tuple[Any] | None = None + ) -> None: + """Call a stored database procedure with the given name. + + This method is not supported by Spanner. + """ + logger.debug("NO-OP: Calling stored procedure") + raise errors.NotSupportedError("Stored procedures are not supported.") diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py new file mode 100644 index 000000000000..8225d374eee8 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py @@ -0,0 +1,241 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Spanner Python Driver Errors. + +DBAPI-defined Exceptions are defined in the following hierarchy:: + + Exceptions + |__Warning + |__Error + |__InterfaceError + |__DatabaseError + |__DataError + |__OperationalError + |__IntegrityError + |__InternalError + |__ProgrammingError + |__NotSupportedError + +""" + +from typing import Any, Sequence + +from google.api_core.exceptions import GoogleAPICallError + + +class Warning(Exception): + """Important DB API warning.""" + + pass + + +class Error(Exception): + """The base class for all the DB API exceptions. + + Does not include :class:`Warning`. 
+ """ + + def _is_error_cause_instance_of_google_api_exception(self) -> bool: + return isinstance(self.__cause__, GoogleAPICallError) + + @property + def reason(self) -> str | None: + """The reason of the error. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[str, None]: An optional string containing reason of the error. + """ + return ( + self.__cause__.reason + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def domain(self) -> str | None: + """The logical grouping to which the "reason" belongs. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[str, None]: An optional string containing a logical grouping + to which the "reason" belongs. + """ + return ( + self.__cause__.domain + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def metadata(self) -> dict[str, str] | None: + """Additional structured details about this error. + Reference: + https://cloud.google.com/apis/design/errors#error_info + Returns: + Union[Dict[str, str], None]: An optional object containing + structured details about the error. + """ + return ( + self.__cause__.metadata + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + @property + def details(self) -> Sequence[Any] | None: + """Information contained in google.rpc.status.details. + Reference: + https://cloud.google.com/apis/design/errors#error_model + https://cloud.google.com/apis/design/errors#error_details + Returns: + Sequence[Any]: A list of structured objects from + error_details.proto + """ + return ( + self.__cause__.details + if self._is_error_cause_instance_of_google_api_exception() + else None + ) + + +class InterfaceError(Error): + """ + Error related to the database interface + rather than the database itself. 
+    """
+
+    pass
+
+
+class DatabaseError(Error):
+    """Error related to the database."""
+
+    pass
+
+
+class DataError(DatabaseError):
+    """
+    Error due to problems with the processed data like
+    division by zero, numeric value out of range, etc.
+    """
+
+    pass
+
+
+class OperationalError(DatabaseError):
+    """
+    Error related to the database's operation, e.g. an
+    unexpected disconnect, the data source name is not
+    found, a transaction could not be processed, a
+    memory allocation error, etc.
+    """
+
+    pass
+
+
+class IntegrityError(DatabaseError):
+    """
+    Error for cases where the relational integrity of the
+    database is affected, e.g. a foreign key check fails.
+    """
+
+    pass
+
+
+class InternalError(DatabaseError):
+    """
+    Internal database error, e.g. the cursor is not valid
+    anymore, the transaction is out of sync, etc.
+    """
+
+    pass
+
+
+class ProgrammingError(DatabaseError):
+    """
+    Programming error, e.g. table not found or already
+    exists, syntax error in the SQL statement, wrong
+    number of parameters specified, etc.
+    """
+
+    pass
+
+
+class NotSupportedError(DatabaseError):
+    """
+    Error raised when a method or database API not
+    supported by the database is used. 
+    """
+
+    pass
+
+
+def map_spanner_error(error: Exception) -> Error:
+    """Map SpannerLibError or GoogleAPICallError to DB API 2.0 errors."""
+    from google.api_core import exceptions
+    from google.cloud.spannerlib.internal.errors import SpannerLibError
+
+    match error:
+        # Handle SpannerLibError by matching on the internal
+        # error_code attribute
+        case SpannerLibError(error_code=code):
+            match code:
+                # 3 - INVALID_ARGUMENT
+                # 5 - NOT_FOUND
+                case 3 | 5:
+                    return ProgrammingError(error)
+                # 6 - ALREADY_EXISTS
+                case 6:
+                    return IntegrityError(error)
+                # 11 - OUT_OF_RANGE
+                case 11:
+                    return DataError(error)
+                # 1 - CANCELLED
+                # 4 - DEADLINE_EXCEEDED
+                # 7 - PERMISSION_DENIED
+                # 9 - FAILED_PRECONDITION
+                # 10 - ABORTED
+                # 14 - UNAVAILABLE
+                # 16 - UNAUTHENTICATED
+                case 1 | 4 | 7 | 9 | 10 | 14 | 16:
+                    return OperationalError(error)
+                # 13 - INTERNAL
+                case 13:
+                    return InternalError(error)
+                case _:
+                    return DatabaseError(error)
+
+        # Handle standard api_core exceptions
+        case exceptions.InvalidArgument() | exceptions.NotFound():
+            return ProgrammingError(error)
+        case exceptions.AlreadyExists():
+            return IntegrityError(error)
+        case exceptions.OutOfRange():
+            return DataError(error)
+        case (
+            exceptions.FailedPrecondition()
+            | exceptions.Unauthenticated()
+            | exceptions.PermissionDenied()
+            | exceptions.DeadlineExceeded()
+            | exceptions.ServiceUnavailable()
+            | exceptions.Aborted()
+            | exceptions.Cancelled()
+        ):
+            return OperationalError(error)
+        case exceptions.InternalServerError():
+            return InternalError(error)
+        case _:
+            return DatabaseError(error)
diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py
new file mode 100644
index 000000000000..3b3d228ee743
--- /dev/null
+++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/types.py
@@ -0,0 +1,170 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache 
def Date(year: int, month: int, day: int) -> datetime.date:
    """DB API 2.0 date constructor.

    Args:
        year (int): Year component.
        month (int): Month component.
        day (int): Day component.

    Returns:
        datetime.date: The constructed date.
    """
    return datetime.date(year=year, month=month, day=day)


def Time(hour: int, minute: int, second: int) -> datetime.time:
    """DB API 2.0 time constructor.

    Args:
        hour (int): Hour component.
        minute (int): Minute component.
        second (int): Second component.

    Returns:
        datetime.time: The constructed time.
    """
    return datetime.time(hour=hour, minute=minute, second=second)


def Timestamp(
    year: int, month: int, day: int, hour: int, minute: int, second: int
) -> datetime.datetime:
    """DB API 2.0 timestamp constructor.

    Args:
        year (int): Year component.
        month (int): Month component.
        day (int): Day component.
        hour (int): Hour component.
        minute (int): Minute component.
        second (int): Second component.

    Returns:
        datetime.datetime: The constructed timestamp.
    """
    return datetime.datetime(year, month, day, hour, minute, second)


def DateFromTicks(ticks: float) -> datetime.date:
    """DB API 2.0 constructor: date from seconds since the epoch.

    Args:
        ticks (float): Seconds since the epoch.

    Returns:
        datetime.date: The local date for that instant.
    """
    return datetime.date.fromtimestamp(ticks)
def TimeFromTicks(ticks: float) -> datetime.time:
    """DB API 2.0 constructor: time-of-day from seconds since the epoch.

    Args:
        ticks (float): Seconds since the epoch.

    Returns:
        datetime.time: The time portion of the local timestamp.
    """
    return datetime.datetime.fromtimestamp(ticks).time()


def TimestampFromTicks(ticks: float) -> datetime.datetime:
    """DB API 2.0 constructor: timestamp from seconds since the epoch.

    Args:
        ticks (float): Seconds since the epoch.

    Returns:
        datetime.datetime: The local timestamp for that instant.
    """
    return datetime.datetime.fromtimestamp(ticks)


def Binary(string: str | bytes) -> bytes:
    """DB API 2.0 constructor for a binary value.

    Args:
        string (str | bytes): Text (UTF-8 encoded) or raw bytes.

    Returns:
        bytes: The binary representation.
    """
    if isinstance(string, str):
        return string.encode("utf-8")
    return bytes(string)


class DBAPITypeObject:
    """PEP 249 type object used in cursor descriptions.

    Compares equal to any of the Spanner type-code names it was
    constructed with.
    """

    def __init__(self, *values: str):
        self.values = values

    def __eq__(self, other: Any) -> bool:
        return other in self.values


STRING = DBAPITypeObject("STRING")
BINARY = DBAPITypeObject("BYTES", "PROTO")
NUMBER = DBAPITypeObject("INT64", "FLOAT64", "NUMERIC")
DATETIME = DBAPITypeObject("TIMESTAMP", "DATE")
BOOLEAN = DBAPITypeObject("BOOL")
ROWID = DBAPITypeObject()  # Spanner has no rowid concept


class Type(object):
    """Spanner type codes re-exported for driver consumers."""

    STRING = TypeCode.STRING
    BYTES = TypeCode.BYTES
    BOOL = TypeCode.BOOL
    INT64 = TypeCode.INT64
    FLOAT64 = TypeCode.FLOAT64
    DATE = TypeCode.DATE
    TIMESTAMP = TypeCode.TIMESTAMP
    NUMERIC = TypeCode.NUMERIC
    JSON = TypeCode.JSON
    PROTO = TypeCode.PROTO
    ENUM = TypeCode.ENUM


def _type_code_to_dbapi_type(type_code: int) -> DBAPITypeObject:
    """Translate a Spanner ``TypeCode`` into a DB API type object.

    JSON is reported as STRING; unknown codes also fall back to
    STRING, matching the original if-chain behavior.
    """
    mapping = {
        TypeCode.STRING: STRING,
        TypeCode.JSON: STRING,
        TypeCode.BYTES: BINARY,
        TypeCode.PROTO: BINARY,
        TypeCode.BOOL: BOOLEAN,
        TypeCode.INT64: NUMBER,
        TypeCode.FLOAT64: NUMBER,
        TypeCode.NUMERIC: NUMBER,
        TypeCode.DATE: DATETIME,
        TypeCode.TIMESTAMP: DATETIME,
    }
    return mapping.get(type_code, STRING)
"""Helper functions for system tests."""

import os

SPANNER_EMULATOR_HOST = os.environ.get("SPANNER_EMULATOR_HOST")
# When no emulator host is configured, tests run against production.
TEST_ON_PROD = not bool(SPANNER_EMULATOR_HOST)

if TEST_ON_PROD:
    PROJECT_ID = os.environ.get("SPANNER_PROJECT_ID")
    INSTANCE_ID = os.environ.get("SPANNER_INSTANCE_ID")
    DATABASE_ID = os.environ.get("SPANNER_DATABASE_ID")

    if not PROJECT_ID or not INSTANCE_ID or not DATABASE_ID:
        raise ValueError(
            "SPANNER_PROJECT_ID, SPANNER_INSTANCE_ID, and SPANNER_DATABASE_ID "
            "must be set when running tests on production."
        )
else:
    # Fixed identifiers are fine against the emulator.
    PROJECT_ID = "test-project"
    INSTANCE_ID = "test-instance"
    DATABASE_ID = "test-db"

PROD_TEST_CONNECTION_STRING = (
    f"projects/{PROJECT_ID}/instances/{INSTANCE_ID}/databases/{DATABASE_ID}"
)

# BUG FIX: a "/" separator is required between the emulator host and the
# database path; without it the host and "projects/..." run together
# (e.g. "localhost:9010projects/test-project/...").
# NOTE(review): confirm against the spannerlib connection-string format.
EMULATOR_TEST_CONNECTION_STRING = (
    f"{SPANNER_EMULATOR_HOST}/"
    f"projects/{PROJECT_ID}"
    f"/instances/{INSTANCE_ID}"
    f"/databases/{DATABASE_ID}"
    "?autoConfigEmulator=true"
)


def setup_test_env() -> None:
    """Log the emulator host and connection string in use."""
    if not TEST_ON_PROD:
        print(f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}")
        print(f"Using Connection String: {get_test_connection_string()}")


def get_test_connection_string() -> str:
    """Return the connection string for the current test target."""
    if TEST_ON_PROD:
        return PROD_TEST_CONNECTION_STRING
    return EMULATOR_TEST_CONNECTION_STRING
"""Tests for connection.py"""

from google.cloud.spanner_driver import connect

from ._helper import get_test_connection_string


class TestConnect:
    def test_cursor(self):
        """Opening a connection and a cursor via context managers works."""
        conn_str = get_test_connection_string()

        with connect(conn_str) as connection:
            assert connection is not None

            with connection.cursor() as cursor:
                assert cursor is not None


class TestConnectMethod:
    """Tests for the connection.py module."""

    def test_connect(self):
        """connect() yields a usable connection as a context manager."""
        with connect(get_test_connection_string()) as connection:
            assert connection is not None
"""Tests for cursor.py"""

from google.cloud.spanner_driver import connect, types

from ._helper import get_test_connection_string


class TestCursor:
    def test_execute(self):
        """execute() fills description; fetchone() returns the row."""
        connection_string = get_test_connection_string()

        with connect(connection_string) as connection:
            assert connection is not None

            with connection.cursor() as cursor:
                assert cursor is not None

                cursor.execute("SELECT 1 AS col1")
                assert cursor.description is not None
                assert cursor.description[0][0] == "col1"
                assert (
                    cursor.description[0][1] == types.NUMBER
                )  # TypeCode.INT64 maps to types.NUMBER

                result = cursor.fetchone()
                assert result == (1,)

    def test_execute_params(self):
        """execute() binds named parameters."""
        connection_string = get_test_connection_string()
        with connect(connection_string) as connection:
            with connection.cursor() as cursor:
                sql = "SELECT @a AS col1"
                params = {"a": 1}
                cursor.execute(sql, params)
                result = cursor.fetchone()
                assert result == (1,)

    def test_execute_dml(self):
        """DDL and DML statements execute and report rowcount."""
        connection_string = get_test_connection_string()
        with connect(connection_string) as connection:
            with connection.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS Singers")

                # Create table
                cursor.execute(
                    """
                    CREATE TABLE Singers (
                        SingerId INT64 NOT NULL,
                        FirstName STRING(1024),
                        LastName STRING(1024),
                    ) PRIMARY KEY (SingerId)
                    """
                )

                # Insert
                cursor.execute(
                    "INSERT INTO Singers (SingerId, FirstName, LastName) "
                    "VALUES (@id, @first, @last)",
                    {"id": 1, "first": "John", "last": "Doe"},
                )
                assert cursor.rowcount == 1

                # Update
                cursor.execute(
                    "UPDATE Singers SET FirstName = 'Jane' WHERE SingerId = 1"
                )
                assert cursor.rowcount == 1

                # Select back to verify
                cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 1")
                row = cursor.fetchone()
                assert row == ("Jane",)

    def test_fetch_methods(self):
        """fetchone/fetchmany/fetchall walk the result set in order."""
        connection_string = get_test_connection_string()
        with connect(connection_string) as connection:
            with connection.cursor() as cursor:
                # Use UNNEST to generate rows
                cursor.execute(
                    "SELECT * FROM UNNEST([1, 2, 3, 4, 5]) AS numbers ORDER BY numbers"
                )

                # Fetch one
                row = cursor.fetchone()
                assert row == (1,)

                # Fetch many
                rows = cursor.fetchmany(2)
                assert len(rows) == 2
                assert rows[0] == (2,)
                assert rows[1] == (3,)

                # Fetch all remaining
                rows = cursor.fetchall()
                assert len(rows) == 2
                assert rows[0] == (4,)
                assert rows[1] == (5,)

    def test_data_types(self):
        """Scalar Spanner types round-trip through fetchone()."""
        import datetime

        connection_string = get_test_connection_string()
        with connect(connection_string) as connection:
            with connection.cursor() as cursor:
                sql = """
                SELECT
                    1 AS int_val,
                    3.14 AS float_val,
                    TRUE AS bool_val,
                    'hello' AS str_val,
                    b'bytes' AS bytes_val,
                    DATE '2023-01-01' AS date_val,
                    TIMESTAMP '2023-01-01T12:00:00Z' AS timestamp_val
                """
                cursor.execute(sql)
                row = cursor.fetchone()

                assert row[0] == 1
                assert row[1] == 3.14
                assert row[2] is True
                assert row[3] == "hello"
                assert row[4] == b"bytes"
                # BUG FIX: the original asserted bytes_val twice and never
                # checked the date/timestamp columns.
                assert row[5] == datetime.date(2023, 1, 1)
                # NOTE(review): exact timestamp equality depends on the
                # driver's timezone handling — confirm and tighten.
                assert row[6] is not None
"""Tests for error handling in cursor.py and connection.py"""

import pytest

from google.cloud.spanner_driver import connect, errors

from ._helper import get_test_connection_string

DDL = """
CREATE TABLE Singers (
    SingerId INT64 NOT NULL,
    FirstName STRING(1024),
    LastName STRING(1024),
) PRIMARY KEY (SingerId)
"""


class TestErrors:
    def setup_method(self):
        """Re-create the Singers table before each test."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS Singers")
                cursor.execute(DDL)

    def test_programming_error_table_not_found(self):
        """Selecting from a non-existent table raises ProgrammingError."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                with pytest.raises(errors.ProgrammingError):
                    cursor.execute("SELECT * FROM NonExistentTable")

    def test_integrity_error_duplicate_pk(self):
        """Inserting a duplicate primary key raises IntegrityError."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                sql = (
                    "INSERT INTO Singers (SingerId, FirstName, LastName) "
                    "VALUES (@id, @first, @last)"
                )
                params = {"id": 1, "first": "Alice", "last": "A"}

                cursor.execute(sql, params)

                # A second insert with the same primary key must fail.
                with pytest.raises(errors.IntegrityError):
                    cursor.execute(sql, params)

    def test_operational_error_syntax(self):
        """An incomplete SQL statement raises ProgrammingError."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                with pytest.raises(errors.ProgrammingError):
                    cursor.execute("SELECT * FROM Singers WHERE")
"""Tests for executemany support in cursor.py"""

from google.cloud.spanner_driver import connect

from ._helper import get_test_connection_string

DDL = """
CREATE TABLE Singers (
    SingerId INT64 NOT NULL,
    FirstName STRING(1024),
    LastName STRING(1024),
) PRIMARY KEY (SingerId)
"""


class TestExecuteMany:
    def setup_method(self):
        """Re-create the Singers table before each test."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS Singers")
                cursor.execute(DDL)

    def test_executemany(self):
        """executemany() batches DML and sums the row counts."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                sql = (
                    "INSERT INTO Singers (SingerId, FirstName, LastName) "
                    "VALUES (@id, @first, @last)"
                )
                batch = [
                    {"id": 1, "first": "Alice", "last": "A"},
                    {"id": 2, "first": "Bob", "last": "B"},
                    {"id": 3, "first": "Charlie", "last": "C"},
                ]

                cursor.executemany(sql, batch)

                assert cursor.rowcount == 3

                # Read the table back to confirm all three rows landed.
                cursor.execute("SELECT * FROM Singers ORDER BY SingerId")
                fetched = cursor.fetchall()
                assert len(fetched) == 3
                assert fetched[0] == (1, "Alice", "A")
                assert fetched[1] == (2, "Bob", "B")
                assert fetched[2] == (3, "Charlie", "C")
"""Tests for transaction support in connection.py"""

from google.cloud.spanner_driver import connect

from ._helper import get_test_connection_string

DDL = """
CREATE TABLE Singers (
    SingerId INT64 NOT NULL,
    FirstName STRING(1024),
    LastName STRING(1024),
) PRIMARY KEY (SingerId)
"""


class TestTransaction:
    def setup_method(self):
        """Re-create the Singers table before each test."""
        with connect(get_test_connection_string()) as connection:
            with connection.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS Singers")
                cursor.execute(DDL)

    def test_commit(self):
        """Committed changes are visible from a second connection."""
        conn_str = get_test_connection_string()

        # Insert inside an explicit transaction and commit it.
        with connect(conn_str) as conn1:
            conn1.begin()
            with conn1.cursor() as cursor:
                cursor.execute(
                    "INSERT INTO Singers (SingerId, FirstName, LastName) "
                    "VALUES (@id, @first, @last)",
                    {"id": 1, "first": "John", "last": "Doe"},
                )
            conn1.commit()

        # A fresh connection must observe the committed row.
        with connect(conn_str) as conn2:
            with conn2.cursor() as cursor:
                cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 1")
                assert cursor.fetchone() == ("John",)

    def test_rollback(self):
        """Rolled-back changes are discarded."""
        conn_str = get_test_connection_string()

        # Insert, then roll the transaction back.
        with connect(conn_str) as conn1:
            conn1.begin()
            with conn1.cursor() as cursor:
                cursor.execute(
                    "INSERT INTO Singers (SingerId, FirstName, LastName) "
                    "VALUES (@id, @first, @last)",
                    {"id": 2, "first": "Jane", "last": "Doe"},
                )
            conn1.rollback()

        # The row must not be visible anywhere.
        with connect(conn_str) as conn2:
            with conn2.cursor() as cursor:
                cursor.execute("SELECT FirstName FROM Singers WHERE SingerId = 2")
                assert cursor.fetchone() is None

    def test_isolation(self):
        """Uncommitted changes stay invisible to other connections."""
        conn_str = get_test_connection_string()

        conn1 = connect(conn_str)
        conn2 = connect(conn_str)

        try:
            conn1.begin()
            curs1 = conn1.cursor()
            curs2 = conn2.cursor()

            # Insert in conn1 without committing.
            curs1.execute(
                "INSERT INTO Singers (SingerId, FirstName, LastName) "
                "VALUES (@id, @first, @last)",
                {"id": 3, "first": "Bob", "last": "Smith"},
            )

            # conn2 must not see the uncommitted row.
            curs2.execute("SELECT FirstName FROM Singers WHERE SingerId = 3")
            assert curs2.fetchone() is None

            conn1.commit()

            # After commit the row becomes visible to conn2.
            curs2.execute("SELECT FirstName FROM Singers WHERE SingerId = 3")
            assert curs2.fetchone() == ("Bob",)
        finally:
            conn1.close()
            conn2.close()
# 1. Define Exception Classes
class MockGoogleAPICallError(Exception):
    """Minimal stand-in for google.api_core.exceptions.GoogleAPICallError.

    Mirrors the attributes the driver may read; the values are fixed
    dummies sufficient for unit testing.
    """

    def __init__(self, message=None, errors=None, response=None, **kwargs):
        super().__init__(message)
        self.message = message
        self.errors = errors
        self.response = response
        self.reason = "reason"
        self.domain = "domain"
        self.metadata = {}
        self.details = []


# Concrete api_core exception stand-ins. Each is just an empty subclass,
# so generate them in one pass instead of fourteen copy-pasted classes.
for _exc_name in (
    "AlreadyExists",
    "NotFound",
    "InvalidArgument",
    "FailedPrecondition",
    "OutOfRange",
    "Unauthenticated",
    "PermissionDenied",
    "DeadlineExceeded",
    "ServiceUnavailable",
    "Aborted",
    "InternalServerError",
    "Unknown",
    "Cancelled",
    "DataLoss",
):
    globals()[_exc_name] = type(_exc_name, (MockGoogleAPICallError,), {})


class MockSpannerLibError(Exception):
    """Stand-in for google.cloud.spannerlib.internal.errors.SpannerLibError."""
# 2. Define Type/Proto Classes
class MockTypeCode:
    """Integer constants mirroring google.cloud.spanner_v1.TypeCode."""

    STRING = 1
    BYTES = 2
    BOOL = 3
    INT64 = 4
    FLOAT64 = 5
    DATE = 6
    TIMESTAMP = 7
    NUMERIC = 8
    JSON = 9
    PROTO = 10
    ENUM = 11


class MockExecuteSqlRequest:
    """Stand-in for spanner_v1.ExecuteSqlRequest."""

    def __init__(self, sql=None, params=None):
        self.sql = sql
        self.params = params


class MockType:
    """Stand-in for spanner_v1.Type; equality is by type code."""

    def __init__(self, code):
        self.code = code

    def __eq__(self, other):
        return isinstance(other, MockType) and self.code == other.code

    def __repr__(self):
        return f"MockType(code={self.code})"


class MockStructField:
    """Stand-in for spanner_v1.StructField."""

    def __init__(self, name, type_):
        self.name = name
        self.type = type_  # Avoid conflict with builtin type

    def __eq__(self, other):
        return (
            isinstance(other, MockStructField)
            and self.name == other.name
            and self.type == other.type
        )


class MockStructType:
    """Stand-in for spanner_v1.StructType."""

    def __init__(self, fields):
        self.fields = fields


# 3. Create Module Mocks
# google.cloud.spanner_v1
spanner_v1 = MagicMock()
spanner_v1.TypeCode = MockTypeCode
spanner_v1.ExecuteSqlRequest = MockExecuteSqlRequest
spanner_v1.Type = MockType
spanner_v1.StructField = MockStructField
spanner_v1.StructType = MockStructType

# google.cloud.spanner_v1.types
spanner_v1_types = MagicMock()
spanner_v1_types.Type = MockType
spanner_v1_types.StructField = MockStructField
spanner_v1_types.StructType = MockStructType

# google.api_core.exceptions
exceptions_module = MagicMock()
exceptions_module.GoogleAPICallError = MockGoogleAPICallError
exceptions_module.AlreadyExists = AlreadyExists
exceptions_module.NotFound = NotFound
exceptions_module.InvalidArgument = InvalidArgument
exceptions_module.FailedPrecondition = FailedPrecondition
exceptions_module.OutOfRange = OutOfRange
exceptions_module.Unauthenticated = Unauthenticated
exceptions_module.PermissionDenied = PermissionDenied
exceptions_module.DeadlineExceeded = DeadlineExceeded
exceptions_module.ServiceUnavailable = ServiceUnavailable
exceptions_module.Aborted = Aborted
exceptions_module.InternalServerError = InternalServerError
exceptions_module.Unknown = Unknown
exceptions_module.Cancelled = Cancelled
exceptions_module.DataLoss = DataLoss

# google.cloud.spannerlib
spannerlib = MagicMock()
# internal.errors
spannerlib_internal_errors = MagicMock()
spannerlib_internal_errors.SpannerLibError = MockSpannerLibError
spannerlib.internal.errors = spannerlib_internal_errors

# pool
spannerlib_pool = MagicMock()
spannerlib.pool = spannerlib_pool


# pool.Pool class
class MockPool:
    """Stand-in for spannerlib.pool.Pool."""

    @staticmethod
    def create_pool(connection_string):
        # Mirrors the real factory signature; always hands back a pool.
        return MockPool()

    def create_connection(self):
        return MagicMock()


spannerlib.pool.Pool = MockPool

# connection
spannerlib_connection = MagicMock()
spannerlib.connection = spannerlib_connection

# 4. Inject into sys.modules so driver imports resolve to the mocks
sys.modules["google.cloud.spanner_v1"] = spanner_v1
sys.modules["google.cloud.spanner_v1.types"] = spanner_v1_types
sys.modules["google.api_core.exceptions"] = exceptions_module
sys.modules["google.api_core"] = MagicMock(exceptions=exceptions_module)
sys.modules["google.cloud.spannerlib"] = spannerlib
sys.modules["google.cloud.spannerlib.internal"] = spannerlib.internal
sys.modules["google.cloud.spannerlib.internal.errors"] = spannerlib_internal_errors
sys.modules["google.cloud.spannerlib.pool"] = spannerlib_pool
sys.modules["google.cloud.spannerlib.connection"] = spannerlib_connection


# 5. Patch google.cloud (was mis-numbered as a second "4")
# This is tricky because google is a namespace package
# but spannerlib might need to be explicitly set in google.cloud
google.cloud.spannerlib = spannerlib
google.cloud.spanner_v1 = spanner_v1
import unittest
from unittest import mock

from google.cloud import spanner_driver
from google.cloud.spanner_driver import connection, errors


class TestConnect(unittest.TestCase):
    def test_connect(self):
        """connect() builds a pool and pulls one connection from it."""
        conn_str = "spanner://projects/p/instances/i/databases/d"

        with mock.patch(
            "google.cloud.spannerlib.pool.Pool.create_pool"
        ) as mock_create_pool:
            mock_pool = mock.Mock()
            mock_create_pool.return_value = mock_pool
            internal = mock.Mock()
            mock_pool.create_connection.return_value = internal

            conn = spanner_driver.connect(conn_str)

            self.assertIsInstance(conn, connection.Connection)
            mock_create_pool.assert_called_once_with(conn_str)
            mock_pool.create_connection.assert_called_once()


class TestConnection(unittest.TestCase):
    def setUp(self):
        self.mock_internal_conn = mock.Mock()
        self.conn = connection.Connection(self.mock_internal_conn)

    def test_cursor(self):
        """cursor() returns a Cursor bound to this connection."""
        cur = self.conn.cursor()
        self.assertIsInstance(cur, spanner_driver.Cursor)
        self.assertEqual(cur._connection, self.conn)

    def test_cursor_closed(self):
        """cursor() on a closed connection raises InterfaceError."""
        self.conn.close()
        with self.assertRaises(errors.InterfaceError):
            self.conn.cursor()

    def test_begin(self):
        self.conn.begin()
        self.mock_internal_conn.begin_transaction.assert_called_once()

    def test_begin_error(self):
        """Failures in begin surface as DatabaseError."""
        self.mock_internal_conn.begin_transaction.side_effect = Exception(
            "Internal Error"
        )
        with self.assertRaises(errors.DatabaseError):
            self.conn.begin()

    def test_commit(self):
        self.conn.commit()
        self.mock_internal_conn.commit.assert_called_once()

    def test_commit_error(self):
        """Commit errors are caught by the driver, not re-raised."""
        self.mock_internal_conn.commit.side_effect = Exception("Commit Failed")
        try:
            self.conn.commit()
        except Exception:
            self.fail("commit() raised Exception unexpectedly!")
        self.mock_internal_conn.commit.assert_called_once()

    def test_rollback(self):
        self.conn.rollback()
        self.mock_internal_conn.rollback.assert_called_once()

    def test_rollback_error(self):
        # Like commit, rollback errors are caught and logged.
        self.mock_internal_conn.rollback.side_effect = Exception("Rollback Failed")
        try:
            self.conn.rollback()
        except Exception:
            self.fail("rollback() raised Exception unexpectedly!")
        self.mock_internal_conn.rollback.assert_called_once()

    def test_close(self):
        self.assertFalse(self.conn._closed)
        self.conn.close()
        self.assertTrue(self.conn._closed)
        self.mock_internal_conn.close.assert_called_once()

    def test_close_idempotent(self):
        """A second close() raises InterfaceError."""
        self.conn.close()
        self.mock_internal_conn.close.reset_mock()
        self.assertRaises(errors.InterfaceError, self.conn.close)

    def test_messages(self):
        self.assertEqual(self.conn.messages, [])

    def test_context_manager(self):
        """Leaving the context closes the connection."""
        with self.conn as c:
            self.assertEqual(c, self.conn)
            self.assertFalse(c._closed)
        self.assertTrue(self.conn._closed)
        self.mock_internal_conn.close.assert_called_once()
import datetime
import unittest
import uuid
from unittest import mock

from google.cloud.spanner_v1 import ExecuteSqlRequest, TypeCode
from google.cloud.spanner_v1.types import StructField, Type

from google.cloud.spanner_driver import cursor


class TestCursor(unittest.TestCase):
    """Unit tests for the DB-API ``Cursor`` against a mocked internal
    connection.

    Row values travel as proto ``Value`` objects whose scalars are encoded
    as strings; the shared fixture helpers below build those mocks once
    instead of copy-pasting them into every test.
    """

    def setUp(self):
        self.mock_connection = mock.Mock()
        self.mock_internal_conn = mock.Mock()
        self.mock_connection._internal_conn = self.mock_internal_conn
        self.cursor = cursor.Cursor(self.mock_connection)

    # ------------------------------------------------------------------
    # Shared fixture helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _int64_metadata():
        """Result-set metadata describing a single INT64 column ``col1``."""
        metadata = mock.Mock()
        metadata.row_type.fields = [
            StructField(name="col1", type_=Type(code=TypeCode.INT64))
        ]
        return metadata

    @staticmethod
    def _string_value_row(value):
        """A mock proto row holding one value encoded as ``string_value``."""
        proto_value = mock.Mock()
        proto_value.WhichOneof.return_value = "string_value"
        proto_value.string_value = value
        row = mock.Mock()
        row.values = [proto_value]
        return row

    def _rows_with_int64_column(self):
        """Attach a mock result set (single INT64 column) to the cursor."""
        mock_rows = mock.Mock()
        mock_rows.metadata.return_value = self._int64_metadata()
        self.cursor._rows = mock_rows
        return mock_rows

    def test_init(self):
        self.assertEqual(self.cursor._connection, self.mock_connection)

    def test_execute(self):
        # SELECT: metadata carries row-type fields, so description is
        # populated, rowcount stays -1 (unknown), and the result set is kept.
        operation = "SELECT * FROM table"
        mock_rows = mock.Mock()
        mock_rows.metadata.return_value = self._int64_metadata()
        self.mock_internal_conn.execute.return_value = mock_rows

        self.cursor.execute(operation)

        self.mock_internal_conn.execute.assert_called_once()
        request = self.mock_internal_conn.execute.call_args[0][0]
        self.assertIsInstance(request, ExecuteSqlRequest)
        self.assertEqual(request.sql, operation)
        self.assertEqual(self.cursor._rowcount, -1)
        self.assertEqual(self.cursor._rows, mock_rows)

    def test_execute_dml(self):
        # DML: no row-type metadata, so the driver records the update count,
        # closes the result set, and clears _rows.
        operation = "UPDATE table SET col=1"
        mock_rows = mock.Mock()
        mock_rows.metadata.return_value = None
        mock_rows.update_count.return_value = 10
        self.mock_internal_conn.execute.return_value = mock_rows

        self.cursor.execute(operation)

        self.assertEqual(self.cursor._rowcount, 10)
        mock_rows.close.assert_called_once()
        self.assertIsNone(self.cursor._rows)

    def test_execute_with_params(self):
        operation = "SELECT * FROM table WHERE id=@id"
        mock_rows = mock.Mock()
        mock_rows.metadata.return_value = mock.Mock()
        self.mock_internal_conn.execute.return_value = mock_rows

        self.cursor.execute(operation, {"id": 1})

        request = self.mock_internal_conn.execute.call_args[0][0]
        self.assertEqual(request.sql, operation)
        # INT64 parameter values are transported as strings.
        self.assertEqual(request.params, {"id": "1"})

    def test_executemany(self):
        operation = "INSERT INTO table (id) VALUES (@id)"
        params_seq = [{"id": 1, "name": "val1"}, {"id": 2}]

        # Batch response: two result sets, one row affected each.
        mock_response = mock.Mock()
        mock_result_set1 = mock.Mock()
        mock_result_set1.stats.row_count_exact = 1
        mock_result_set2 = mock.Mock()
        mock_result_set2.stats.row_count_exact = 1
        mock_response.result_sets = [mock_result_set1, mock_result_set2]
        self.mock_internal_conn.execute_batch.return_value = mock_response

        with mock.patch(
            "google.cloud.spanner_driver.cursor.ExecuteBatchDmlRequest"
        ) as MockRequest:
            mock_request_instance = MockRequest.return_value
            # Real list so appended statements can be observed.
            mock_request_instance.statements = []
            MockStatement = mock.Mock()
            MockRequest.Statement = MockStatement

            self.cursor.executemany(operation, params_seq)

            self.mock_internal_conn.execute_batch.assert_called_once_with(
                mock_request_instance
            )
            self.assertEqual(len(mock_request_instance.statements), 2)
            self.assertEqual(MockStatement.call_count, 2)
            self.assertEqual(
                MockStatement.call_args_list[0].kwargs["sql"], operation
            )

        # Total rowcount is the sum of the per-statement counts.
        self.assertEqual(self.cursor.rowcount, 2)

    def test_fetchone(self):
        mock_rows = self._rows_with_int64_column()
        mock_rows.next.return_value = self._string_value_row("1")

        self.assertEqual(self.cursor.fetchone(), (1,))
        mock_rows.next.assert_called_once()

    def test_fetchone_empty(self):
        mock_rows = mock.Mock()
        self.cursor._rows = mock_rows
        mock_rows.next.side_effect = StopIteration

        self.assertIsNone(self.cursor.fetchone())

    def test_fetchmany(self):
        mock_rows = self._rows_with_int64_column()
        mock_rows.next.side_effect = [
            self._string_value_row("1"),
            self._string_value_row("2"),
            StopIteration,
        ]

        # Only two rows exist, so a size of 5 returns both and stops.
        rows = self.cursor.fetchmany(size=5)
        self.assertEqual(rows, [(1,), (2,)])

    def test_fetchall(self):
        mock_rows = self._rows_with_int64_column()
        mock_rows.next.side_effect = [
            self._string_value_row("1"),
            self._string_value_row("2"),
            StopIteration,
        ]

        self.assertEqual(len(self.cursor.fetchall()), 2)

    def test_description(self):
        mock_rows = mock.Mock()
        self.cursor._rows = mock_rows
        mock_metadata = mock.Mock()
        mock_metadata.row_type.fields = [
            StructField(name="col1", type_=Type(code=TypeCode.INT64)),
            StructField(name="col2", type_=Type(code=TypeCode.STRING)),
        ]
        mock_rows.metadata.return_value = mock_metadata

        desc = self.cursor.description
        self.assertEqual(len(desc), 2)
        self.assertEqual(desc[0][0], "col1")
        self.assertEqual(desc[1][0], "col2")

    def test_close(self):
        mock_rows = mock.Mock()
        self.cursor._rows = mock_rows

        self.cursor.close()

        self.assertTrue(self.cursor._closed)
        mock_rows.close.assert_called_once()

    def test_context_manager(self):
        with self.cursor as c:
            self.assertEqual(c, self.cursor)
        self.assertTrue(self.cursor._closed)

    def test_iterator(self):
        mock_rows = self._rows_with_int64_column()
        mock_rows.next.side_effect = [self._string_value_row("1"), StopIteration]

        # __next__ delegates to fetchone().
        it = iter(self.cursor)
        self.assertEqual(next(it), (1,))
        with self.assertRaises(StopIteration):
            next(it)

    def test_prepare_params(self):
        # None -> empty parameter and type maps.
        converted, types = self.cursor._prepare_params(None)
        self.assertEqual(converted, {})
        self.assertEqual(types, {})

        # Dict (GoogleSQL dialect, named parameters).
        uuid_val = uuid.uuid4()
        dt_val = datetime.datetime(2024, 1, 1, 12, 0, 0)
        date_val = datetime.date(2024, 1, 1)
        params = {
            "int_val": 123,
            "bool_val": True,
            "float_val": 1.23,
            "bytes_val": b"bytes",
            "str_val": "string",
            "uuid_val": uuid_val,
            "dt_val": dt_val,
            "date_val": date_val,
            "none_val": None,
        }
        converted, types = self.cursor._prepare_params(params)

        self.assertEqual(converted["int_val"], "123")
        self.assertEqual(types["int_val"].code, TypeCode.INT64)

        self.assertEqual(converted["bool_val"], True)
        self.assertEqual(types["bool_val"].code, TypeCode.BOOL)

        self.assertEqual(converted["float_val"], 1.23)
        self.assertEqual(types["float_val"].code, TypeCode.FLOAT64)

        self.assertEqual(converted["bytes_val"], b"bytes")
        self.assertEqual(types["bytes_val"].code, TypeCode.BYTES)

        self.assertEqual(converted["str_val"], "string")
        self.assertEqual(types["str_val"].code, TypeCode.STRING)

        self.assertEqual(converted["uuid_val"], str(uuid_val))
        self.assertEqual(types["uuid_val"].code, TypeCode.STRING)

        self.assertEqual(converted["dt_val"], str(dt_val))
        self.assertEqual(types["dt_val"].code, TypeCode.TIMESTAMP)

        self.assertEqual(converted["date_val"], str(date_val))
        self.assertEqual(types["date_val"].code, TypeCode.DATE)

        # None values are passed through with no declared type.
        self.assertIsNone(converted["none_val"])
        self.assertNotIn("none_val", types)

        # List (PostgreSQL dialect): positional values map to P1, P2, ...
        converted, types = self.cursor._prepare_params([1, "test"])

        self.assertEqual(converted["P1"], "1")
        self.assertEqual(types["P1"].code, TypeCode.INT64)
        self.assertEqual(converted["P2"], "test")
        self.assertEqual(types["P2"].code, TypeCode.STRING)

    def test_prepare_params_unsupported_type(self):
        # Bare scalars are not an accepted parameter container; only dict,
        # list, tuple, or None are allowed.
        with self.assertRaises(cursor.errors.ProgrammingError):
            self.cursor._prepare_params(123)
import unittest

from google.api_core import exceptions
from google.cloud.spannerlib.internal.errors import SpannerLibError

from google.cloud.spanner_driver import errors


class TestErrors(unittest.TestCase):
    """Verify ``map_spanner_error`` maps source exceptions to the expected
    DB-API 2.0 error classes."""

    def _assert_maps_to(self, source_error, expected_type):
        """Map ``source_error`` and check the resulting DB-API error class."""
        mapped = errors.map_spanner_error(source_error)
        self.assertIsInstance(mapped, expected_type)

    def test_map_spanner_lib_error(self):
        self._assert_maps_to(SpannerLibError("Internal Error"), errors.DatabaseError)

    def test_map_not_found(self):
        self._assert_maps_to(exceptions.NotFound("Not found"), errors.ProgrammingError)

    def test_map_already_exists(self):
        self._assert_maps_to(exceptions.AlreadyExists("Exists"), errors.IntegrityError)

    def test_map_invalid_argument(self):
        self._assert_maps_to(
            exceptions.InvalidArgument("Invalid"), errors.ProgrammingError
        )

    def test_map_failed_precondition(self):
        self._assert_maps_to(
            exceptions.FailedPrecondition("Precondition"), errors.OperationalError
        )

    def test_map_out_of_range(self):
        self._assert_maps_to(exceptions.OutOfRange("OOR"), errors.DataError)

    def test_map_unknown(self):
        self._assert_maps_to(exceptions.Unknown("Unknown"), errors.DatabaseError)
import datetime
import unittest

from google.cloud.spanner_v1 import TypeCode

from google.cloud.spanner_driver import types


class TestTypes(unittest.TestCase):
    """Tests for the DB-API type constructors and type objects."""

    def test_date(self):
        self.assertEqual(types.Date(2025, 1, 1), datetime.date(2025, 1, 1))

    def test_time(self):
        self.assertEqual(types.Time(12, 30, 0), datetime.time(12, 30, 0))

    def test_timestamp(self):
        self.assertEqual(
            types.Timestamp(2025, 1, 1, 12, 30, 0),
            datetime.datetime(2025, 1, 1, 12, 30, 0),
        )

    def test_binary(self):
        # Binary accepts both text and bytes input and yields bytes.
        self.assertEqual(types.Binary("hello"), b"hello")
        self.assertEqual(types.Binary(b"world"), b"world")

    def test_type_objects(self):
        self.assertEqual(types.STRING, types.STRING)
        self.assertNotEqual(types.STRING, types.NUMBER)
        # DBAPITypeObject equality matches member names via 'in'.
        self.assertEqual(types.STRING, "STRING")

    def test_type_code_mapping(self):
        # Each Spanner TypeCode maps to one DB-API type object.
        expected_pairs = [
            (TypeCode.STRING, types.STRING),
            (TypeCode.INT64, types.NUMBER),
            (TypeCode.BOOL, types.BOOLEAN),
            (TypeCode.FLOAT64, types.NUMBER),
            (TypeCode.BYTES, types.BINARY),
            (TypeCode.TIMESTAMP, types.DATETIME),
            (TypeCode.DATE, types.DATETIME),
            (TypeCode.JSON, types.STRING),
        ]
        for code, dbapi_type in expected_pairs:
            with self.subTest(code=code):
                self.assertEqual(types._type_code_to_dbapi_type(code), dbapi_type)
5f7001d4e9ef1204a48ffe62f015b0936ae929e5 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 25 Mar 2026 12:29:47 +0000 Subject: [PATCH 05/12] refactor: improve error handling, parameter validation, and logging in Spanner DBAPI driver --- .../google/cloud/spanner_driver/__init__.py | 42 +++++++++---------- .../google/cloud/spanner_driver/connection.py | 6 +-- .../google/cloud/spanner_driver/cursor.py | 11 +++-- .../tests/system/test_cursor.py | 1 - .../tests/unit/test_connection.py | 9 +--- .../tests/unit/test_cursor.py | 4 ++ 6 files changed, 37 insertions(+), 36 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py index d898b418c6f5..32cac8125778 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py @@ -52,32 +52,32 @@ logger.addHandler(logging.NullHandler()) __all__: list[str] = [ - "apilevel", - "threadsafety", - "paramstyle", + "BINARY", + "Binary", "Connection", - "connect", "Cursor", - "Date", - "Time", - "Timestamp", - "DateFromTicks", - "TimeFromTicks", - "TimestampFromTicks", - "Binary", - "STRING", - "BINARY", - "NUMBER", "DATETIME", - "ROWID", - "InterfaceError", - "ProgrammingError", - "OperationalError", - "DatabaseError", "DataError", - "NotSupportedError", + "DatabaseError", + "Date", + "DateFromTicks", + "Error", "IntegrityError", + "InterfaceError", "InternalError", + "NUMBER", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "ROWID", + "STRING", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", "Warning", - "Error", + "apilevel", + "connect", + "paramstyle", + "threadsafety", ] diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py 
b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py index 12e4c3638d98..e81cefd2497f 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py @@ -89,8 +89,8 @@ def commit(self) -> None: try: self._internal_conn.commit() except Exception as e: - # raise errors.map_spanner_error(e) logger.debug(f"Commit failed {e}") + raise errors.map_spanner_error(e) @check_not_closed def rollback(self) -> None: @@ -102,8 +102,8 @@ def rollback(self) -> None: try: self._internal_conn.rollback() except Exception as e: - # raise errors.map_spanner_error(e) logger.debug(f"Rollback failed {e}") + raise errors.map_spanner_error(e) def close(self) -> None: """Close the connection now. @@ -127,7 +127,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -def connect(connection_string: str, **kwargs: Any) -> Connection: +def connect(connection_string: str) -> Connection: logger.debug(f"Connecting to {connection_string}") # Create the pool pool = Pool.create_pool(connection_string) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py index a81e95ef47e8..4278ef791d17 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py @@ -116,7 +116,8 @@ def description(self) -> tuple[tuple[Any, ...], ...] | None: ) ) return tuple(desc) - except Exception: + except Exception as e: + logger.warning("Could not determine cursor description: %s", e) return None @property @@ -165,8 +166,9 @@ def _prepare_params( # GoogleSQL Dialect: Named parameters @name are mapped directly. 
iterator = parameters.items() else: - # If strictly required, raise an error for unsupported types - return {}, {} + raise errors.ProgrammingError( + f"Parameters must be a dict, list, or tuple, not {type(parameters).__name__}" + ) for key, value in iterator: if value is None: @@ -429,7 +431,8 @@ def nextset(self) -> bool | None: if next_metadata: return True return None - except Exception: + except Exception as e: + logger.warning("Could not determine next set of results: %s", e) return None def __enter__(self) -> "Cursor": diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py index 5719b4030fa5..5287fc646008 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/system/test_cursor.py @@ -141,4 +141,3 @@ def test_data_types(self): assert row[2] is True assert row[3] == "hello" assert row[4] == b"bytes" - assert row[4] == b"bytes" diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py index ed9a0fa18736..56feea8792ea 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_connection.py @@ -70,10 +70,8 @@ def test_commit(self): def test_commit_error(self): self.mock_internal_conn.commit.side_effect = Exception("Commit Failed") - try: + with self.assertRaises(errors.DatabaseError): self.conn.commit() - except Exception: - self.fail("commit() raised Exception unexpectedly!") self.mock_internal_conn.commit.assert_called_once() def test_rollback(self): @@ -81,12 +79,9 @@ def test_rollback(self): self.mock_internal_conn.rollback.assert_called_once() def test_rollback_error(self): - # Similar to commit, rollback errors are caught and logged self.mock_internal_conn.rollback.side_effect = 
Exception("Rollback Failed") - try: + with self.assertRaises(errors.DatabaseError): self.conn.rollback() - except Exception: - self.fail("rollback() raised Exception unexpectedly!") self.mock_internal_conn.rollback.assert_called_once() def test_close(self): diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py index 7cb6cf4e992f..9042a66c0645 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/unit/test_cursor.py @@ -347,3 +347,7 @@ def test_prepare_params(self): self.assertEqual(converted["P2"], "test") self.assertEqual(types["P2"].code, TypeCode.STRING) + + def test_prepare_params_unsupported_type(self): + with self.assertRaises(cursor.errors.ProgrammingError): + self.cursor._prepare_params(123) # Int is not supported directly From e7279581b52252842cb4388ace2239eb6ffc7b23 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Tue, 14 Apr 2026 10:06:29 +0000 Subject: [PATCH 06/12] chore: import version string from version module instead of hardcoding it --- .../google/cloud/spanner_driver/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py index 32cac8125778..751086e45369 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/__init__.py @@ -45,8 +45,9 @@ Timestamp, TimestampFromTicks, ) +from .version import __version__ as _version -__version__: Final[str] = "0.0.1" +__version__: Final[str] = _version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) From 6e36c69dccecc6512c911caa68026db42bba98a2 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 
15 Apr 2026 05:57:37 +0000 Subject: [PATCH 07/12] refactor: improve type hinting and resolve static analysis errors in cursor, errors, and connection modules --- .../google/cloud/spanner_driver/connection.py | 2 +- .../google/cloud/spanner_driver/cursor.py | 11 ++++++----- .../google/cloud/spanner_driver/errors.py | 12 +++++++----- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py index e81cefd2497f..37aa5b2fb988 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/connection.py @@ -14,7 +14,7 @@ import logging from typing import Any -from google.cloud.spannerlib.pool import Pool +from google.cloud.spannerlib.pool import Pool # type: ignore[import-untyped] from . import errors from .cursor import Cursor diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py index 4278ef791d17..9798c53abbd7 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/cursor.py @@ -16,7 +16,7 @@ import logging import uuid from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable from google.cloud.spanner_v1 import ( ExecuteBatchDmlRequest, @@ -135,7 +135,7 @@ def rowcount(self) -> int: def _prepare_params( self, parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None - ) -> (dict[str, Any] | None, dict[str, Type] | None): + ) -> tuple[dict[str, Any] | None, dict[str, Type] | None]: """ Prepares parameters for Spanner execution @@ -154,10 +154,11 @@ def _prepare_params( if not parameters: 
return {}, {} - converted_params = {} + converted_params: dict[str, Any] = {} param_types = {} # Normalize input to an iterable of (key, value) + iterator: Iterable[tuple[str, Any]] if isinstance(parameters, (list, tuple)): # PostgreSQL Dialect: Positional parameters $1, $2... are # mapped to P1, P2... @@ -233,7 +234,7 @@ def execute( request = ExecuteSqlRequest(sql=operation) params, _ = self._prepare_params(parameters) - request.params = params + request.params = params # type: ignore[assignment] try: self._rows = self._connection._internal_conn.execute(request) @@ -271,7 +272,7 @@ def executemany( for parameters in seq_of_parameters: statement = ExecuteBatchDmlRequest.Statement(sql=operation) params, _ = self._prepare_params(parameters) - statement.params = params + statement.params = params # type: ignore[assignment] request.statements.append(statement) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py index 8225d374eee8..41c3fe92c943 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py @@ -60,7 +60,7 @@ def reason(self) -> str | None: """ return ( self.__cause__.reason - if self._is_error_cause_instance_of_google_api_exception() + if isinstance(self.__cause__, GoogleAPICallError) else None ) @@ -75,7 +75,7 @@ def domain(self) -> str | None: """ return ( self.__cause__.domain - if self._is_error_cause_instance_of_google_api_exception() + if isinstance(self.__cause__, GoogleAPICallError) else None ) @@ -90,7 +90,7 @@ def metadata(self) -> dict[str, str] | None: """ return ( self.__cause__.metadata - if self._is_error_cause_instance_of_google_api_exception() + if isinstance(self.__cause__, GoogleAPICallError) else None ) @@ -106,7 +106,7 @@ def details(self) -> Sequence[Any] | None: """ return ( 
self.__cause__.details - if self._is_error_cause_instance_of_google_api_exception() + if isinstance(self.__cause__, GoogleAPICallError) else None ) @@ -186,7 +186,9 @@ class NotSupportedError(DatabaseError): def map_spanner_error(error: Exception) -> Error: """Map SpannerLibError or GoogleAPICallError to DB API 2.0 errors.""" from google.api_core import exceptions - from google.cloud.spannerlib.internal.errors import SpannerLibError + from google.cloud.spannerlib.internal.errors import ( + SpannerLibError, # type: ignore[import-untyped] + ) match error: # Handle SpannerLibError by matching on the internal From ea774cece18c1965989d3b2fb08ddd5ebe7ed91b Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 15 Apr 2026 06:05:43 +0000 Subject: [PATCH 08/12] refactor: reformat SpannerLibError import for consistent style --- .../google/cloud/spanner_driver/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py index 41c3fe92c943..1c63a3ebe641 100644 --- a/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py +++ b/packages/google-cloud-spanner-dbapi-driver/google/cloud/spanner_driver/errors.py @@ -186,8 +186,8 @@ class NotSupportedError(DatabaseError): def map_spanner_error(error: Exception) -> Error: """Map SpannerLibError or GoogleAPICallError to DB API 2.0 errors.""" from google.api_core import exceptions - from google.cloud.spannerlib.internal.errors import ( - SpannerLibError, # type: ignore[import-untyped] + from google.cloud.spannerlib.internal.errors import ( # type: ignore[import-untyped] + SpannerLibError, ) match error: From 768868f8fa406803598f162da5042b9511657d04 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Wed, 18 Mar 2026 08:16:06 +0000 Subject: [PATCH 09/12] feat: initial scaffolding for the `google-cloud-spanner-dbapi-driver` package, 
including core files, tests, documentation, and build configurations. --- packages/google-cloud-spanner-dbapi-driver/README.rst | 2 +- packages/google-cloud-spanner-dbapi-driver/docs/README.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/README.rst b/packages/google-cloud-spanner-dbapi-driver/README.rst index 859400cc6da4..29d7be7da11e 100644 --- a/packages/google-cloud-spanner-dbapi-driver/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-spanner-dbapi-driver.svg :target: https://pypi.org/project/google-cloud-spanner-dbapi-driver/ diff --git a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst index 859400cc6da4..29d7be7da11e 100644 --- a/packages/google-cloud-spanner-dbapi-driver/docs/README.rst +++ b/packages/google-cloud-spanner-dbapi-driver/docs/README.rst @@ -3,7 +3,7 @@ Python DBAPI 2.0 Compliant Driver for Spanner |stable| |pypi| |versions| -.. |stable| image:: https://img.shields.io/badge/support-preview-orange.svg +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels .. 
# Dependencies installed into the compliance-test session in addition to
# the package under test itself.
COMPLIANCE_TEST_STANDARD_DEPENDENCIES = [
    "pytest",
    "spannerlib-python",
    "google-cloud-spanner",
]


@nox.session(python=DEFAULT_PYTHON_VERSION)
def compliance(session):
    """Run the DBAPI 2.0 compliance test suite against the Spanner emulator.

    The session is skipped unless SPANNER_EMULATOR_HOST is set, so the
    suite can never run accidentally against a real Spanner instance.
    """
    if not os.environ.get("SPANNER_EMULATOR_HOST", ""):
        session.skip(
            "Emulator host must be set via SPANNER_EMULATOR_HOST environment variable"
        )

    session.install(*COMPLIANCE_TEST_STANDARD_DEPENDENCIES)
    session.install("-e", ".")

    # Allow callers to narrow the run via positional args; default to the
    # whole compliance directory.
    test_paths = (
        session.posargs if session.posargs else [os.path.join("tests", "compliance")]
    )
    # NOTE(review): the original passed ``env={}`` to session.run. session.run
    # inherits the session environment when ``env`` is omitted, and an explicit
    # empty mapping is at best dead code and at worst (depending on nox
    # version semantics) hides SPANNER_EMULATOR_HOST from the test process,
    # so it is dropped. "pytest" replaces the deprecated "py.test" alias.
    session.run(
        "pytest",
        MODE,
        f"--junitxml=compliance_{session.python}_sponge_log.xml",
        *test_paths,
    )
diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py new file mode 100644 index 000000000000..45968fa9d1fa --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py @@ -0,0 +1,39 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for compliance tests.""" + +import os + +SPANNER_EMULATOR_HOST = os.environ.get("SPANNER_EMULATOR_HOST") + +PROJECT_ID = "test-project" +INSTANCE_ID = "test-instance" +DATABASE_ID = "test-db" + +EMULATOR_TEST_CONNECTION_STRING = ( + f"{SPANNER_EMULATOR_HOST}" + f"/projects/{PROJECT_ID}" + f"/instances/{INSTANCE_ID}" + f"/databases/{DATABASE_ID}" + "?autoConfigEmulator=true" +) + + +def setup_test_env() -> None: + print(f"Set SPANNER_EMULATOR_HOST to {SPANNER_EMULATOR_HOST}") + print(f"Using Connection String: {get_test_connection_string()}") + + +def get_test_connection_string() -> str: + return EMULATOR_TEST_CONNECTION_STRING diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py new file mode 100644 index 000000000000..390ec0a2adf2 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py @@ -0,0 +1,1040 @@ +# Copyright 2026 Google LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DBAPI 2.0 Compliance Test +""" + +import time +import unittest +from unittest.mock import MagicMock + +from .sql_factory import SQLFactory + + +def encode(s: str) -> bytes: + return s.encode("utf-8") + + +def decode(b: bytes) -> str: + return b.decode("utf-8") + + +class DBAPI20ComplianceTestBase(unittest.TestCase): + """ + Base class for DBAPI 2.0 Compliance Tests. + See PEP 249 for details: https://peps.python.org/pep-0249/ + """ + + __test__ = False + driver = None + errors = None + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + dialect = "GoogleSQL" + + lower_func = "lower" # Name of stored procedure to convert string->lowercase + + @property + def sql_factory(self): + return SQLFactory.get_factory(self.dialect) + + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def setUp(self): + self.cleanup() + + def tearDown(self): + self.cleanup() + + def cleanup(self): + try: + con = self._connect() + try: + cur = con.cursor() + for ddl in self.sql_factory.stmt_ddl_drop_all_cmds: + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. 
+ pass + finally: + con.close() + except Exception: + pass + + def _connect(self): + try: + r = self.driver.connect(*self.connect_args, **self.connect_kw_args) + except AttributeError: + self.fail("No connect method found in self.driver module") + return r + + def _execute_select1(self, cur): + cur.execute(self.sql_factory.stmt_dql_select_1) + + def _simple_queries(self, cur): + # DDL + cur.execute(self.sql_factory.stmt_ddl_create_table1) + # DML + for sql in self.sql_factory.populate_table1(): + cur.execute(sql) + # DQL + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + _ = cur.fetchall() + self.assertTrue(cur.rowcount in (-1, len(self.sql_factory.names_table1))) + + def _parametized_queries(self, cur): + # DDL + cur.execute(self.sql_factory.stmt_ddl_create_table2) + # DML + cur.execute( + self.sql_factory.stmt_dml_insert_table2("101, 'Moms Lasagna', 1, True, ''") + ) + self.assertTrue(cur.rowcount in (-1, 1)) + + if self.driver.paramstyle == "qmark": + cur.execute( + self.sql_factory.stmt_dml_insert_table2( + "102, ?, 1, True, 'thi%%s :may ca%%(u)se? troub:1e'" + ), + ("Chocolate Brownie",), + ) + elif self.driver.paramstyle == "numeric": + cur.execute( + self.sql_factory.stmt_dml_insert_table2( + "102, :1, 1, True,'thi%%s :may ca%%(u)se? troub:1e'" + ), + ("Chocolate Brownie",), + ) + elif self.driver.paramstyle == "named": + cur.execute( + self.sql_factory.stmt_dml_insert_table2( + "102, :item_name, 1, True, 'thi%%s :may ca%%(u)se? troub:1e'" + ), + {"item_name": "Chocolate Brownie"}, + ) + elif self.driver.paramstyle == "format": + cur.execute( + self.sql_factory.stmt_dml_insert_table2( + "102, %%s, 1, True, 'thi%%%%s :may ca%%%%(u)se? troub:1e'" + ), + ("Chocolate Brownie",), + ) + elif self.driver.paramstyle == "pyformat": + cur.execute( + self.sql_factory.stmt_dml_insert_table2( + "102, %%(item_name), 1, True, 'thi%%%%s :may ca%%%%(u)se? 
troub:1e'" + ), + {"item_name": "Chocolate Brownie"}, + ) + else: + self.fail("Invalid paramstyle") + + self.assertTrue(cur.rowcount in (-1, 1)) + + # DQL + cur.execute(self.sql_factory.stmt_dql_select_all_table2()) + rows = cur.fetchall() + + self.assertEqual(len(rows), 2, "cursor.fetchall returned too few rows") + item_name = [rows[0][1], rows[1][1]] + item_name.sort() + self.assertEqual( + item_name[0], + "Chocolate Brownie", + "cursor.fetchall retrieved incorrect data, or data inserted incorrectly", + ) + self.assertEqual( + item_name[1], + "Moms Lasagna", + "cursor.fetchall retrieved incorrect data, or data inserted incorrectly", + ) + + trouble = "thi%s :may ca%(u)se? troub:1e" + self.assertEqual( + rows[0][4], + trouble, + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly. Got=%s, Expected=%s" % (repr(rows[0][4]), repr(trouble)), + ) + self.assertEqual( + rows[1][4], + trouble, + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly. Got=%s, Expected=%s" % (repr(rows[1][4]), repr(trouble)), + ) + + # ========================================================================= + # Module Interface + # ========================================================================= + + def test_module_attributes(self): + """Test module-level attributes. + See PEP 249 Module Interface. + """ + self.assertTrue(hasattr(self.driver, "apilevel")) + self.assertTrue(hasattr(self.driver, "threadsafety")) + self.assertTrue(hasattr(self.driver, "paramstyle")) + self.assertTrue(hasattr(self.driver, "connect")) + + def test_apilevel(self): + """Test module.apilevel. + Must be '2.0'. + """ + try: + apilevel = self.driver.apilevel + self.assertEqual(apilevel, "2.0", "Driver apilevel must be '2.0'") + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + """Test module.threadsafety. + Must be 0, 1, 2, or 3. 
+ """ + try: + threadsafety = self.driver.threadsafety + self.assertTrue( + threadsafety in (0, 1, 2, 3), + "threadsafety must be one of 0, 1, 2, 3", + ) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + """Test module.paramstyle. + Must be one of 'qmark', 'numeric', 'named', 'format', 'pyformat'. + """ + try: + paramstyle = self.driver.paramstyle + self.assertTrue( + paramstyle in ("qmark", "numeric", "named", "format", "pyformat"), + "Invalid paramstyle", + ) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_exceptions(self): + """Test module exception hierarchy. + See PEP 249 Exceptions. + """ + self.assertTrue(issubclass(self.errors.Warning, Exception)) + self.assertTrue(issubclass(self.errors.Error, Exception)) + self.assertTrue(issubclass(self.errors.InterfaceError, self.errors.Error)) + self.assertTrue(issubclass(self.errors.DatabaseError, self.errors.Error)) + self.assertTrue(issubclass(self.errors.DataError, self.errors.DatabaseError)) + self.assertTrue( + issubclass(self.errors.OperationalError, self.errors.DatabaseError) + ) + self.assertTrue( + issubclass(self.errors.IntegrityError, self.errors.DatabaseError) + ) + self.assertTrue( + issubclass(self.errors.InternalError, self.errors.DatabaseError) + ) + self.assertTrue( + issubclass(self.errors.ProgrammingError, self.errors.DatabaseError) + ) + self.assertTrue( + issubclass(self.errors.NotSupportedError, self.errors.DatabaseError) + ) + + # ========================================================================= + # Connection Objects + # ========================================================================= + + def test_connect(self): + """Test that connect returns a connection object.""" + conn = self._connect() + conn.close() + + def test_connection_attributes(self): + """Test Connection object attributes/methods.""" + # Mock connection internal + mock_internal = MagicMock() + conn = 
self.driver.Connection(mock_internal) + + self.assertTrue(hasattr(conn, "close")) + self.assertTrue(hasattr(conn, "commit")) + self.assertTrue(hasattr(conn, "rollback")) + self.assertTrue(hasattr(conn, "cursor")) + # Optional but checked because we added it + self.assertTrue(hasattr(conn, "messages")) + + def test_close(self): + """Test connection.close().""" + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called + # after connection closed + self.assertRaises(self.driver.Error, self._execute_select1, cur) + + # connection.commit should raise an Error if called + # after connection closed + self.assertRaises(self.driver.Error, con.commit) + + def test_non_idempotent_close(self): + """Test that calling close() twice raises an Error + (optional behavior).""" + con = self._connect() + con.close() + # connection.close should raise an Error if called more than once + self.assertRaises(self.driver.Error, con.close) + + def test_commit(self): + """Test connection.commit().""" + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + """Test connection.rollback().""" + con = self._connect() + try: + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con, "rollback"): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + finally: + con.close() + + def test_cursor(self): + """Test connection.cursor().""" + con = self._connect() + try: + curr = con.cursor() + self.assertIsNotNone(curr) + finally: + con.close() + + def test_cursor_isolation(self): + """Test that cursors are isolated (transactionally).""" + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + cur1.execute(self.sql_factory.stmt_ddl_create_table1) + # 
DDL usually requires a clean slate or commit in some test envs + con.commit() + cur1.execute( + self.sql_factory.stmt_dml_insert_table1("1, 'Innocent Alice', 100") + ) + con.commit() + cur2.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + users = cur2.fetchone() + + self.assertEqual(len(users), 1) + self.assertEqual(users[0], "Innocent Alice") + finally: + con.close() + + # ========================================================================= + # Cursor Objects + # ========================================================================= + + def test_cursor_attributes(self): + """Test Cursor object attributes/methods.""" + mock_conn = MagicMock() + cursor = self.driver.Cursor(mock_conn) + + self.assertTrue(hasattr(cursor, "description")) + self.assertTrue(hasattr(cursor, "rowcount")) + self.assertTrue(hasattr(cursor, "callproc")) + self.assertTrue(hasattr(cursor, "close")) + self.assertTrue(hasattr(cursor, "execute")) + self.assertTrue(hasattr(cursor, "executemany")) + self.assertTrue(hasattr(cursor, "fetchone")) + self.assertTrue(hasattr(cursor, "fetchmany")) + self.assertTrue(hasattr(cursor, "fetchall")) + self.assertTrue(hasattr(cursor, "nextset")) + self.assertTrue(hasattr(cursor, "arraysize")) + self.assertTrue(hasattr(cursor, "setinputsizes")) + self.assertTrue(hasattr(cursor, "setoutputsize")) + + # Test iterator + self.assertTrue(hasattr(cursor, "__iter__")) + self.assertTrue(hasattr(cursor, "__next__")) + + # Test callproc raising NotSupportedError (mandatory by + # default unless implemented) + with self.assertRaises(self.errors.NotSupportedError): + cursor.callproc("proc") + + def test_description(self): + """Test cursor.description.""" + con = self._connect() + try: + cur = con.cursor() + cur.execute(self.sql_factory.stmt_ddl_create_table1) + + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + 
cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + self.assertEqual( + len(cur.description), + 1, + "cursor.description describes too many columns", + ) + self.assertEqual( + len(cur.description[0]), + 7, + "cursor.description[x] tuples must have 7 elements", + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" + % cur.description[0][1], + ) + + # Make sure self.description gets reset + cur.execute(self.sql_factory.stmt_ddl_create_table2) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. DDL)", + ) + finally: + con.close() + + def test_rowcount(self): + """Test cursor.rowcount.""" + con = self._connect() + try: + cur = con.cursor() + cur.execute(self.sql_factory.stmt_ddl_create_table1) + self.assertTrue( + cur.rowcount in (-1, 0), # Bug #543885 + "cursor.rowcount should be -1 or 0 after executing no-result " + "statements", + ) + cur.execute( + self.sql_factory.stmt_dml_insert_table1("1, 'Innocent Alice', 100") + ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) + cur.execute(self.sql_factory.stmt_ddl_create_table2) + self.assertTrue( + cur.rowcount in (-1, 0), # Bug #543885 + "cursor.rowcount should be -1 or 0 after executing no-result " + "statements", + ) + finally: + con.close() + + def test_callproc(self): + """Test cursor.callproc().""" + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur, 
"callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") + r = cur.fetchall() + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") + except self.driver.NotSupportedError: + pass + finally: + con.close() + + def test_execute(self): + """Test cursor.execute().""" + con = self._connect() + try: + cur = con.cursor() + self._simple_queries(cur) + finally: + con.close() + + @unittest.skip("Failing as params are not yet handled") + def test_execute_with_params(self): + """Test cursor.execute() with parameters.""" + con = self._connect() + try: + cur = con.cursor() + self._parametized_queries(cur) + finally: + con.close() + + @unittest.skip("Failing as params are not yet handled") + def test_executemany_with_params(self): + """Test cursor.executemany() with parameters.""" + con = self._connect() + try: + cur = con.cursor() + # DDL + cur.execute(self.sql_factory.stmt_ddl_create_table2) + + largs = [("Moms Lasagna",), ("Chocolate Brownie",)] + margs = [{"name": "Moms Lasagna"}, {"name": "Chocolate Brownie"}] + if self.driver.paramstyle == "qmark": + cur.executemany( + self.sql_factory.stmt_dml_insert_table2( + "102, ?, 1, True, 'thi%%s :may ca%%(u)se? troub:1e'" + ), + largs, + ) + elif self.driver.paramstyle == "numeric": + cur.executemany( + self.sql_factory.stmt_dml_insert_table2( + "102, :1, 1, True,'thi%%s :may ca%%(u)se? troub:1e'" + ), + largs, + ) + elif self.driver.paramstyle == "named": + cur.executemany( + self.sql_factory.stmt_dml_insert_table2( + "102, :item_name, 1, True, 'thi%%s :may ca%%(u)se? troub:1e'" + ), + margs, + ) + elif self.driver.paramstyle == "format": + cur.executemany( + self.sql_factory.stmt_dml_insert_table2( + "102, %%s, 1, True, 'thi%%%%s :may ca%%%%(u)se? 
troub:1e'" + ), + largs, + ) + elif self.driver.paramstyle == "pyformat": + cur.executemany( + self.sql_factory.stmt_dml_insert_table2( + "102, %%(item_name), 1, True, " + "'thi%%%%s :may ca%%%%(u)se? troub:1e'" + ), + margs, + ) + else: + self.fail("Unknown paramstyle") + + self.assertTrue( + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + + # DQL + cur.execute(self.sql_factory.stmt_dql_select_all_table2()) + rows = cur.fetchall() + self.assertEqual( + len(rows), + 2, + "cursor.fetchall retrieved incorrect number of rows", + ) + item_names = [rows[0][1], rows[1][1]] + item_names.sort() + self.assertEqual( + item_names[0], + "Chocolate Brownie", + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly", + ) + self.assertEqual( + item_names[1], + "Moms Lasagna", + "cursor.fetchall retrieved incorrect data, or data inserted " + "incorrectly", + ) + finally: + con.close() + + def test_fetchone(self): + """Test cursor.fetchone().""" + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error, cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + cur.execute(self.sql_factory.stmt_ddl_create_table1) + self.assertRaises(self.driver.Error, cur.fetchone) + + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + cur.execute( + self.sql_factory.stmt_dml_insert_table1("1, 'Innocent Alice', 100") + ) + self.assertRaises(self.driver.Error, cur.fetchone) + + 
cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + row = cur.fetchone() + self.assertEqual( + len(row), + 1, + "cursor.fetchone should have retrieved a single row", + ) + self.assertEqual( + row[0], + "Innocent Alice", + "cursor.fetchone retrieved incorrect data", + ) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if no more rows available", + ) + self.assertTrue(cur.rowcount in (-1, 1)) + finally: + con.close() + + def test_fetchmany(self): + """Test cursor.fetchmany().""" + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) + + cur.execute(self.sql_factory.stmt_ddl_create_table1) + for sql in self.sql_factory.populate_table1(): + cur.execute(sql) + + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + + r = cur.fetchmany() + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", + ) + + cur.arraysize = 10 + r = cur.fetchmany(2) # Should get 3 rows + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " + "results are exhausted", + ) + + self.assertTrue(cur.rowcount in (-1, 5)) + + # Same as above, using cursor.arraysize + cur.arraysize = 3 + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 3, "cursor.arraysize not being honoured by fetchmany" + ) + + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + + r = cur.fetchmany() # Should be an empty sequence + 
self.assertEqual(len(r), 0) + + self.assertTrue(cur.rowcount in (-1, 5)) + + cur.arraysize = 5 + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 5)) + self.assertEqual(len(rows), 5) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0, 5): + self.assertEqual( + rows[i], + self.sql_factory.names_table1[i], + "incorrect data retrieved by cursor.fetchmany", + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, 5)) + + cur.execute(self.sql_factory.stmt_ddl_create_table2) + cur.execute(self.sql_factory.stmt_dql_select_cols_table2("item_name")) + rows = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "query retrieved no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) + + for sql in self.sql_factory.populate_table2(): + cur.execute(sql) + + cur.execute(self.sql_factory.stmt_dql_select_cols_table2("item_name")) + cur.arraysize = 10 + rows = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(rows), 7) + self.assertTrue(cur.rowcount in (-1, 7)) + + finally: + con.close() + + def test_fetchall(self): + """Test cursor.fetchall().""" + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + cur.execute(self.sql_factory.stmt_ddl_create_table1) + for sql in self.sql_factory.populate_table1(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + 
self.assertRaises(self.driver.Error, cur.fetchall) + + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1, len(self.sql_factory.names_table1))) + self.assertEqual( + len(rows), + len(self.sql_factory.names_table1), + "cursor.fetchall did not retrieve all rows", + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0, len(self.sql_factory.names_table1)): + self.assertEqual( + rows[i], + self.sql_factory.names_table1[i], + "cursor.fetchall retrieved incorrect rows", + ) + rows = cur.fetchall() + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.sql_factory.names_table1))) + + cur.execute(self.sql_factory.stmt_ddl_create_table2) + cur.execute(self.sql_factory.stmt_dql_select_cols_table2("item_name")) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) + + finally: + con.close() + + def test_mixedfetch(self): + """Test mixing fetchone, fetchmany, and fetchall.""" + con = self._connect() + try: + cur = con.cursor() + cur.execute(self.sql_factory.stmt_ddl_create_table1) + for sql in self.sql_factory.populate_table1(): + cur.execute(sql) + + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows5 = cur.fetchall() + + self.assertTrue(cur.rowcount in (-1, len(self.sql_factory.names_table1))) + self.assertEqual( + len(rows23), 2, "fetchmany returned incorrect number of rows" + ) + self.assertEqual( + len(rows5), 1, "fetchall returned incorrect number of rows" + ) + + rows = [rows1[0]] + rows.extend([rows23[0][0], rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows5[0][0]]) + rows.sort() + for i in 
range(0, len(self.sql_factory.names_table1)): + self.assertEqual( + rows[i], + self.sql_factory.names_table1[i], + "incorrect data retrieved or inserted", + ) + finally: + con.close() + + def help_nextset_setUp(self, cur): + sql = "SELECT 1; SELECT 2;" + cur.execute(sql) + + def help_nextset_tearDown(self, cur): + pass + + def test_nextset(self): + """Test cursor.nextset().""" + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur, "nextset"): + return + + try: + self.help_nextset_setUp(cur) + rows = cur.fetchone() + self.assertEqual(len(rows), 1) + s = cur.nextset() + self.assertEqual(s, True, "Has more return sets, should return True") + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_no_nextset(self): + """Test cursor.nextset() when no more sets exist.""" + con = self._connect() + try: + cur = con.cursor() + sql = "SELECT 1;" + cur.execute(sql) + if not hasattr(cur, "nextset"): + return + + try: + rows = cur.fetchone() + self.assertEqual(len(rows), 1) + s = cur.nextset() + self.assertEqual(s, None, "No more return sets, should return None") + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_arraysize(self): + """Test cursor.arraysize.""" + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + self.assertTrue( + hasattr(cur, "arraysize"), + "cursor.arraysize must be defined", + ) + finally: + con.close() + + def test_setinputsizes(self): + """Test cursor.setinputsizes().""" + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes((25,)) + self._simple_queries(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + """Test cursor.setoutputsize().""" + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000, 0) + self._simple_queries(cur) # Make 
sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + """Extended test for cursor.setoutputsize() (optional).""" + # Real test for setoutputsize is driver dependant + raise NotImplementedError("Driver needed to override this test") + + def test_None(self): + """Test unpacking of NULL values.""" + con = self._connect() + try: + cur = con.cursor() + cur.execute(self.sql_factory.stmt_ddl_create_table1) + # inserting NULL to the second column, because some drivers might + # need the first one to be primary key, which means it needs + # to have a non-NULL value + cur.execute(self.sql_factory.stmt_dml_insert_table1("1, NULL, 100")) + cur.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) + row = cur.fetchone() + self.assertEqual(len(row), 1) + self.assertEqual(row[0], None, "NULL value not returned as None") + finally: + con.close() + + # ========================================================================= + # Type Objects and Constructors + # ========================================================================= + + def test_type_objects(self): + """Test type objects (STRING, BINARY, etc.).""" + self.assertTrue(hasattr(self.driver, "STRING")) + self.assertTrue(hasattr(self.driver, "BINARY")) + self.assertTrue(hasattr(self.driver, "NUMBER")) + self.assertTrue(hasattr(self.driver, "DATETIME")) + self.assertTrue(hasattr(self.driver, "ROWID")) + + def test_constructors(self): + """Test type constructors (Date, Time, etc.).""" + self.assertTrue(hasattr(self.driver, "Date")) + self.assertTrue(hasattr(self.driver, "Time")) + self.assertTrue(hasattr(self.driver, "Timestamp")) + self.assertTrue(hasattr(self.driver, "DateFromTicks")) + self.assertTrue(hasattr(self.driver, "TimeFromTicks")) + self.assertTrue(hasattr(self.driver, "TimestampFromTicks")) + self.assertTrue(hasattr(self.driver, "Binary")) + + def test_Date(self): + """Test Date constructor.""" + d1 = self.driver.Date(2002, 12, 25) + d2 = 
self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) + # Can we assume this? API doesn't specify, but it seems implied + self.assertEqual(str(d1), str(d2)) + + def test_Time(self): + """Test Time constructor.""" + # 1. Create the target time + t1 = self.driver.Time(13, 45, 30) + + # 2. Create ticks using Local Time (mktime is local) + # We use a dummy date (2001-01-01) + target_tuple = (2001, 1, 1, 13, 45, 30, 0, 0, 0) + ticks = time.mktime(target_tuple) + + t2 = self.driver.TimeFromTicks(ticks) + + # CHECK 1: Ensure they are the same type (likely datetime.time) + self.assertIsInstance(t1, type(t2)) + + # CHECK 2: Compare value semantics, not string representation + # This avoids format differences but still requires timezone alignment + self.assertEqual(t1, t2) + + def test_Timestamp(self): + """Test Timestamp constructor.""" + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + self.assertEqual(str(t1), str(t2)) + + def test_Binary(self): + """Test Binary constructor.""" + s = "Something" + b = self.driver.Binary(encode(s)) + self.assertEqual(s, decode(b)) + + def test_STRING(self): + """Test STRING type object.""" + self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") + + def test_BINARY(self): + """Test BINARY type object.""" + self.assertTrue( + hasattr(self.driver, "BINARY"), "module.BINARY must be defined." + ) + + def test_NUMBER(self): + """Test NUMBER type object.""" + self.assertTrue( + hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." + ) + + def test_DATETIME(self): + """Test DATETIME type object.""" + self.assertTrue( + hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." 
+ ) + + def test_ROWID(self): + """Test ROWID type object.""" + self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/sql_factory.py b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/sql_factory.py new file mode 100644 index 000000000000..7420727bdd13 --- /dev/null +++ b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/sql_factory.py @@ -0,0 +1,204 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +""" +Scenario: The Office Fridge Wars. +This scenario tracks the high-stakes drama of shared office lunches. + +TABLE 1: coworkers +| id | name | trust_level | +--------------------------------------- +| 1 | 'Innocent Alice' | 100 | +| 2 | 'Vegan Sarah' | 95 | +| 3 | 'Manager Bob' | 50 | +| 4 | 'Intern Kevin' | 15 | +| 5 | 'Suspicious Dave'| -10 | + +TABLE 2: office_fridge +| item_id | item_name | owner_id | is_stolen | notes | +--------------------------------------------------------------------------- +-- Alice's perfectly prepped meals (High theft targets) +| 101 | 'Moms Lasagna' | 1 | True | "" | +| 102 | 'Chocolate Brownie' | 1 | True | "" | +-- Sarah's food (Safe because it's Kale) +| 103 | 'Kale & Quinoa Bowl' | 2 | False | "" | +-- Manager Bob's lunch (Too fancy to steal?) 
# NOTE(review): the "Office Fridge Wars" scenario documented above is the
# canonical fixture; sample_table1 / sample_table2 below must stay in sync
# with it.


class SQLFactory(abc.ABC):
    """Dialect-agnostic factory for the SQL statements used by the
    DBAPI 2.0 compliance tests.

    Concrete subclasses supply the dialect-specific DDL (CREATE TABLE);
    the DML and DQL helpers defined here are shared across dialects.
    """

    # Every test table carries this prefix so test artifacts are easy to
    # identify and clean up.
    TABLE_PREFIX = "spd20_"
    TABLE1 = "coworkers"
    TABLE1_COLS = "id, name, trust_level"
    TABLE2 = "office_fridge"
    TABLE2_COLS = "item_id, item_name, owner_id, is_stolen, notes"
    SELECT_1 = "SELECT 1"

    @property
    def table1(self):
        """Fully prefixed name of the coworkers table."""
        return self.TABLE_PREFIX + self.TABLE1

    @property
    def table2(self):
        """Fully prefixed name of the office_fridge table."""
        return self.TABLE_PREFIX + self.TABLE2

    @property
    def stmt_dql_select_1(self):
        """Trivial smoke-test query."""
        return self.SELECT_1

    @property
    @abc.abstractmethod
    def stmt_ddl_create_table1(self):
        """Dialect-specific CREATE TABLE statement for table1."""

    @property
    @abc.abstractmethod
    def stmt_ddl_create_table2(self):
        """Dialect-specific CREATE TABLE statement for table2."""

    @property
    def stmt_ddl_drop_all_cmds(self):
        """DROP statements for every test table."""
        return [self.stmt_ddl_drop_table1, self.stmt_ddl_drop_table2]

    @property
    def stmt_ddl_drop_table1(self):
        return "DROP TABLE %s" % (self.table1)

    @property
    def stmt_ddl_drop_table2(self):
        return "DROP TABLE %s" % (self.table2)

    def stmt_dql_select_all(self, table):
        """SELECT * over *table*."""
        return "SELECT * FROM %s" % (table)

    def stmt_dql_select_all_table1(self):
        return self.stmt_dql_select_all(self.table1)

    def stmt_dql_select_all_table2(self):
        return self.stmt_dql_select_all(self.table2)

    def stmt_dql_select_cols(self, table, col):
        """SELECT of a single column expression from *table*."""
        return "SELECT (%s) FROM %s" % (col, table)

    def stmt_dql_select_cols_table1(self, col):
        return self.stmt_dql_select_cols(self.table1, col)

    def stmt_dql_select_cols_table2(self, col):
        return self.stmt_dql_select_cols(self.table2, col)

    def stmt_dml_insert(self, table, cols, vals):
        """INSERT of pre-rendered literal *vals* into *cols* of *table*."""
        return "INSERT INTO %s (%s) VALUES (%s)" % (table, cols, vals)

    def stmt_dml_insert_table1(self, vals):
        return self.stmt_dml_insert(self.table1, self.TABLE1_COLS, vals)

    def stmt_dml_insert_table2(self, vals):
        return self.stmt_dml_insert(self.table2, self.TABLE2_COLS, vals)

    # Canonical rows for the coworkers table (see scenario above).
    sample_table1 = [
        [1, "Innocent Alice", 100],
        [2, "Vegan Sarah", 95],
        [3, "Manager Bob", 50],
        [4, "Intern Kevin", 15],
        [5, "Suspicious Dave", -10],
    ]
    names_table1 = sorted([row[1] for row in sample_table1])

    def process_row(self, row):
        """Render one fixture row as a comma-separated SQL literal list."""

        def to_sql_literal(value):
            # bool must be tested before numeric handling: isinstance(True, int)
            # is True in Python, so the order of these checks matters.
            if isinstance(value, bool):
                return "TRUE" if value else "FALSE"
            # Strings are wrapped in single quotes.  The fixture values
            # contain no quotes, so no escaping is needed here.
            elif isinstance(value, str):
                return f"'{value}'"
            # Numbers and anything else pass through unchanged.
            else:
                return str(value)

        return ", ".join(map(to_sql_literal, row))

    def populate_table1(self):
        """INSERT statements that load every sample row into table1."""
        return [
            self.stmt_dml_insert_table1(self.process_row(row))
            for row in self.sample_table1
        ]

    # Canonical rows for the office_fridge table.  Fixed to match the
    # scenario documented at the top of this module: every owner_id must
    # reference an existing coworker (ids 1-5); the previous data used
    # owners 6 and 7, which do not exist, and item names that disagreed
    # with the documented fixture.
    sample_table2 = [
        [101, "Moms Lasagna", 1, True, ""],
        [102, "Chocolate Brownie", 1, True, ""],
        [103, "Kale & Quinoa Bowl", 2, False, ""],
        [104, "Expensive Sushi", 3, False, ""],
        [105, "Mega Energy Drink", 4, True, ""],
        [106, "Unlabeled Tupperware Sludge", 5, False, ""],
        [107, "Sandwich labeled - Do Not Eat", 1, True, ""],
    ]
    item_names_table2 = sorted([row[1] for row in sample_table2])

    def populate_table2(self):
        """INSERT statements that load every sample row into table2."""
        return [
            self.stmt_dml_insert_table2(self.process_row(row))
            for row in self.sample_table2
        ]

    @staticmethod
    def get_factory(dialect):
        """Return the SQL factory for *dialect*.

        Args:
            dialect: "GoogleSQL" or "PostgreSQL".

        Raises:
            ValueError: for an unrecognised dialect name.
        """
        if dialect == "PostgreSQL":
            return PostgreSQLFactory()
        elif dialect == "GoogleSQL":
            return GoogleSQLFactory()
        else:
            raise ValueError("Unknown dialect: %s" % dialect)


class GoogleSQLFactory(SQLFactory):
    """GoogleSQL (Cloud Spanner) dialect DDL."""

    @property
    def stmt_ddl_create_table1(self):
        return (
            "CREATE TABLE %s%s "
            "(id INT64, name STRING(100), trust_level INT64) "
            "PRIMARY KEY (id)" % (self.TABLE_PREFIX, self.TABLE1)
        )

    @property
    def stmt_ddl_create_table2(self):
        return (
            "CREATE TABLE %s%s "
            "(item_id INT64, item_name STRING(100), "
            "owner_id INT64, is_stolen BOOL, notes STRING(100)) "
            "PRIMARY KEY (item_id)" % (self.TABLE_PREFIX, self.TABLE2)
        )


class PostgreSQLFactory(SQLFactory):
    """Placeholder for the PostgreSQL dialect (not yet implemented)."""

    @property
    def stmt_ddl_create_table1(self):
        raise NotImplementedError(
            "PostgreSQL dialect support is not yet implemented..."
        )

    @property
    def stmt_ddl_create_table2(self):
        raise NotImplementedError(
            "PostgreSQL dialect support is not yet implemented..."
        )
class TestDBAPICompliance(DBAPI20ComplianceTestBase):
    """Concrete DBAPI 2.0 compliance suite bound to google.cloud.spanner_driver."""

    # Mark this concrete subclass as collectable (the shared base class
    # presumably opts itself out of collection — TODO confirm).
    __test__ = True

    # Module under test and its exception namespace.
    driver = spanner_driver
    errors = errors

    # Connection parameters resolved from the test helper environment.
    connect_args = (get_test_connection_string(),)
    connect_kw_args = {}

    # SQL dialect for this run; overridable via the TEST_DIALECT env var.
    dialect = os.environ.get("TEST_DIALECT", "GoogleSQL")

    def test_setoutputsize(self):
        """cursor.setoutputsize() is optional per PEP 249; intentionally a no-op."""


if __name__ == "__main__":
    unittest.main()
cur1.execute(self.sql_factory.stmt_ddl_create_table1) # DDL usually requires a clean slate or commit in some test envs - con.commit() + try: + con.commit() + except self.errors.OperationalError: + pass cur1.execute( self.sql_factory.stmt_dml_insert_table1("1, 'Innocent Alice', 100") ) - con.commit() + try: + con.commit() + except self.errors.OperationalError: + pass cur2.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) users = cur2.fetchone() From 98987be01c919306aef5e252ed0b45ac9d31cda3 Mon Sep 17 00:00:00 2001 From: Sanjeev Bhatt Date: Mon, 30 Mar 2026 07:06:08 +0000 Subject: [PATCH 12/12] chore: improve logging in compliance tests and update noxfile environment configuration --- .../noxfile.py | 5 ++- .../tests/compliance/_helper.py | 2 +- .../compliance/dbapi20_compliance_testbase.py | 33 +++++++++++-------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/packages/google-cloud-spanner-dbapi-driver/noxfile.py b/packages/google-cloud-spanner-dbapi-driver/noxfile.py index c26d2f0f6a2d..552e14b1427c 100644 --- a/packages/google-cloud-spanner-dbapi-driver/noxfile.py +++ b/packages/google-cloud-spanner-dbapi-driver/noxfile.py @@ -364,7 +364,10 @@ def compliance(session): MODE, f"--junitxml=compliance_{session.python}_sponge_log.xml", *test_paths, - env={}, + env={ + "SPANNER_EMULATOR_HOST": os.environ["SPANNER_EMULATOR_HOST"], + "TEST_DIALECT": os.environ.get("TEST_DIALECT", "GoogleSQL"), + }, ) diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py index 45968fa9d1fa..0c525c32315a 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/_helper.py @@ -31,7 +31,7 @@ def setup_test_env() -> None: - print(f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}") + print(f"Set SPANNER_EMULATOR_HOST to {SPANNER_EMULATOR_HOST}") 
print(f"Using Connection String: {get_test_connection_string()}") diff --git a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py index bd918edcc117..bd4bd7d0697a 100644 --- a/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py +++ b/packages/google-cloud-spanner-dbapi-driver/tests/compliance/dbapi20_compliance_testbase.py @@ -16,12 +16,15 @@ DBAPI 2.0 Compliance Test """ +import logging import time import unittest from unittest.mock import MagicMock from .sql_factory import SQLFactory +logger = logging.getLogger(__name__) + def encode(s: str) -> bytes: return s.encode("utf-8") @@ -59,6 +62,7 @@ def tearDownClass(cls): pass def setUp(self): + logger.info("Executing test: %s", self.id()) self.cleanup() def tearDown(self): @@ -73,19 +77,22 @@ def cleanup(self): try: cur.execute(ddl) con.commit() - except self.driver.Error: + except self.driver.Error as e: # Assume table didn't exist. Other tests will check if # execute is busted. 
- pass + logger.debug( + "Cleanup DDL failed (expected if table missing): %s", e + ) finally: con.close() - except Exception: - pass + except Exception as e: + logger.warning("Cleanup failed with exception: %s", e) def _connect(self): try: r = self.driver.connect(*self.connect_args, **self.connect_kw_args) - except AttributeError: + except AttributeError as e: + logger.error("No connect method found in self.driver module: %s", e) self.fail("No connect method found in self.driver module") return r @@ -323,8 +330,8 @@ def test_rollback(self): try: with self.assertRaises(self.errors.OperationalError): con.rollback() - except self.driver.NotSupportedError: - pass + except self.driver.NotSupportedError as e: + logger.debug("Rollback not supported (expected): %s", e) finally: con.close() @@ -349,15 +356,15 @@ def test_cursor_isolation(self): # DDL usually requires a clean slate or commit in some test envs try: con.commit() - except self.errors.OperationalError: - pass + except self.errors.OperationalError as e: + logger.debug("Empty commit threw expected OperationalError: %s", e) cur1.execute( self.sql_factory.stmt_dml_insert_table1("1, 'Innocent Alice', 100") ) try: con.commit() - except self.errors.OperationalError: - pass + except self.errors.OperationalError as e: + logger.debug("Insert commit threw expected OperationalError: %s", e) cur2.execute(self.sql_factory.stmt_dql_select_cols_table1("name")) users = cur2.fetchone() @@ -492,8 +499,8 @@ def test_callproc(self): self.assertEqual(len(r), 1, "callproc produced no result set") self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") self.assertEqual(r[0][0], "foo", "callproc produced invalid results") - except self.driver.NotSupportedError: - pass + except self.driver.NotSupportedError as e: + logger.debug("callproc not supported (expected): %s", e) finally: con.close()