Skip to content
Permalink
Browse files
fix(db_api): move connection validation into a separate method (#543)
Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
Ilya Gurov and larkee committed Sep 7, 2021
1 parent 23b1600 commit 237ae41d0c0db61f157755cf04f84ef2d146972c
@@ -28,7 +28,7 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

@@ -349,6 +349,30 @@ def run_statement(self, statement, retried=False):
ResultsChecksum() if retried else statement.checksum,
)

def validate(self):
"""
Execute a minimal request to check if the connection
is valid and the related database is reachable.
Raise an exception in case if the connection is closed,
invalid, target database is not found, or the request result
is incorrect.
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if the request result is incorrect.
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
or database doesn't exist.
"""
self._raise_if_closed()

with self.database.snapshot() as snapshot:
result = list(snapshot.execute_sql("SELECT 1"))
if result != [[1]]:
raise OperationalError(
"The checking query (SELECT 1) returned an unexpected result: %s. "
"Expected: [[1]]" % result
)

def __enter__(self):
return self

@@ -399,9 +423,6 @@ def connect(
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
:raises: :class:`ValueError` in case of given instance/database
doesn't exist.
"""

client_info = ClientInfo(
@@ -418,14 +439,7 @@ def connect(
)

instance = client.instance(instance_id)
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

database = instance.database(database_id, pool=pool)
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

conn = Connection(instance, database)
conn = Connection(instance, instance.database(database_id, pool=pool))
if pool is not None:
conn._own_pool = False

@@ -350,3 +350,10 @@ def test_DDL_commit(shared_instance, dbapi_database):

cur.execute("DROP TABLE Singers")
conn.commit()


def test_ping(shared_instance, dbapi_database):
"""Check connection validation method."""
conn = Connection(shared_instance, dbapi_database)
conn.validate()
conn.close()
@@ -88,31 +88,6 @@ def test_w_explicit(self, mock_client):
self.assertIs(connection.database, database)
instance.database.assert_called_once_with(DATABASE, pool=pool)

def test_w_instance_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
instance.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

instance.exists.assert_called_once_with()

def test_w_database_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
database = instance.database.return_value
database.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

database.exists.assert_called_once_with()

def test_w_credential_file_path(self, mock_client):
from google.cloud.spanner_dbapi import connect
from google.cloud.spanner_dbapi import Connection
@@ -624,3 +624,80 @@ def test_retry_transaction_w_empty_response(self):
compare_mock.assert_called_with(checksum, retried_checkum)

run_mock.assert_called_with(statement, retried=True)

def test_validate_ok(self):
def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[1]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

connection.validate()
snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_fail(self):
from google.cloud.spanner_dbapi.exceptions import OperationalError

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[3]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

with self.assertRaises(OperationalError):
connection.validate()

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_error(self):
from google.cloud.exceptions import NotFound

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(side_effect=NotFound("Not found"))

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

with self.assertRaises(NotFound):
connection.validate()

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_closed(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError

connection = self._make_connection()
connection.close()

with self.assertRaises(InterfaceError):
connection.validate()
@@ -332,13 +332,7 @@ def test_executemany_delete_batch_autocommit(self):

sql = "DELETE FROM table WHERE col1 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
@@ -369,13 +363,7 @@ def test_executemany_update_batch_autocommit(self):

sql = "UPDATE table SET col1 = %s WHERE col2 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
@@ -418,13 +406,7 @@ def test_executemany_insert_batch_non_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction = self._transaction_mock()

@@ -461,13 +443,7 @@ def test_executemany_insert_batch_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True

@@ -510,13 +486,7 @@ def test_executemany_insert_batch_failed(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
cursor = connection.cursor()
@@ -546,13 +516,7 @@ def test_executemany_insert_batch_aborted(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Aborted details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction1 = mock.Mock(committed=False, rolled_back=False)
transaction1.batch_update = mock.Mock(

0 comments on commit 237ae41

Please sign in to comment.