diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index d251e0f62a..a50e48804b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -83,7 +83,7 @@ class Connection: should end a that a new one should be started when the next statement is executed. """ - def __init__(self, instance, database, read_only=False): + def __init__(self, instance, database=None, read_only=False): self._instance = instance self._database = database self._ddl_statements = [] @@ -242,6 +242,8 @@ def _session_checkout(self): :rtype: :class:`google.cloud.spanner_v1.session.Session` :returns: Cloud Spanner session object ready to use. """ + if self.database is None: + raise ValueError("Database needs to be passed for this operation") if not self._session: self._session = self.database._pool.get() @@ -252,6 +254,8 @@ def _release_session(self): The session will be returned into the sessions pool. """ + if self.database is None: + raise ValueError("Database needs to be passed for this operation") self.database._pool.put(self._session) self._session = None @@ -368,7 +372,7 @@ def close(self): if self.inside_transaction: self._transaction.rollback() - if self._own_pool: + if self._own_pool and self.database: self.database._pool.clear() self.is_closed = True @@ -378,6 +382,8 @@ def commit(self): This method is non-operational in autocommit mode. """ + if self.database is None: + raise ValueError("Database needs to be passed for this operation") self._snapshot = None if self._autocommit: @@ -420,6 +426,8 @@ def cursor(self): @check_not_closed def run_prior_DDL_statements(self): + if self.database is None: + raise ValueError("Database needs to be passed for this operation") if self._ddl_statements: ddl_statements = self._ddl_statements self._ddl_statements = [] @@ -474,6 +482,8 @@ def validate(self): :raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance or database doesn't exist. """ + if self.database is None: + raise ValueError("Database needs to be passed for this operation") with self.database.snapshot() as snapshot: result = list(snapshot.execute_sql("SELECT 1")) if result != [[1]]: @@ -492,7 +502,7 @@ def __exit__(self, etype, value, traceback): def connect( instance_id, - database_id, + database_id=None, project=None, credentials=None, pool=None, @@ -505,7 +515,7 @@ def connect( :param instance_id: The ID of the instance to connect to. :type database_id: str - :param database_id: The ID of the database to connect to. + :param database_id: (Optional) The ID of the database to connect to. :type project: str :param project: (Optional) The ID of the project which owns the @@ -557,7 +567,9 @@ def connect( raise ValueError("project in url does not match client object project") instance = client.instance(instance_id) - conn = Connection(instance, instance.database(database_id, pool=pool)) + conn = Connection( + instance, instance.database(database_id, pool=pool) if database_id else None + ) if pool is not None: conn._own_pool = False diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index ac3888f35d..91bccedd4c 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -228,6 +228,8 @@ def execute(self, sql, args=None): :type args: list :param args: Additional parameters to supplement the SQL query. """ + if self.connection.database is None: + raise ValueError("Database needs to be passed for this operation") self._itr = None self._result_set = None self._row_count = _UNSET_COUNT @@ -301,6 +303,8 @@ def executemany(self, operation, seq_of_params): :param seq_of_params: Sequence of additional parameters to run the query with. """ + if self.connection.database is None: + raise ValueError("Database needs to be passed for this operation") self._itr = None self._result_set = None self._row_count = _UNSET_COUNT @@ -444,6 +448,8 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): self._row_count = _UNSET_COUNT def _handle_DQL(self, sql, params): + if self.connection.database is None: + raise ValueError("Database needs to be passed for this operation") sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) if self.connection.read_only and not self.connection.autocommit: # initiate or use the existing multi-use snapshot @@ -484,6 +490,8 @@ def list_tables(self): def run_sql_in_snapshot(self, sql, params=None, param_types=None): # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions # hence this method exists to circumvent that limit. + if self.connection.database is None: + raise ValueError("Database needs to be passed for this operation") self.connection.run_prior_DDL_statements() with self.connection.database.snapshot() as snapshot: diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index b077c1feba..7a0ac9e687 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -169,6 +169,14 @@ def test__session_checkout(self, mock_database): connection._session_checkout() self.assertEqual(connection._session, "db_session") + def test__session_checkout_database_error(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE) + + with pytest.raises(ValueError): + connection._session_checkout() + @mock.patch("google.cloud.spanner_v1.database.Database") def test__release_session(self, mock_database): from google.cloud.spanner_dbapi import Connection @@ -182,6 +190,13 @@ def test__release_session(self, mock_database): pool.put.assert_called_once_with("session") self.assertIsNone(connection._session) + def test__release_session_database_error(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE) + with pytest.raises(ValueError): + connection._release_session() + def test_transaction_checkout(self): from google.cloud.spanner_dbapi import Connection @@ -294,6 +309,14 @@ def test_commit(self, mock_warn): AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 ) + def test_commit_database_error(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE) + + with pytest.raises(ValueError): + connection.commit() + @mock.patch.object(warnings, "warn") def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection @@ -347,6 +370,13 @@ def test_run_prior_DDL_statements(self, mock_database): with self.assertRaises(InterfaceError): connection.run_prior_DDL_statements() + def test_run_prior_DDL_statements_database_error(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE) + with pytest.raises(ValueError): + connection.run_prior_DDL_statements() + def test_as_context_manager(self): connection = self._make_connection() with connection as conn: @@ -766,6 +796,14 @@ def test_validate_error(self): snapshot_obj.execute_sql.assert_called_once_with("SELECT 1") + def test_validate_database_error(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE) + + with pytest.raises(ValueError): + connection.validate() + def test_validate_closed(self): from google.cloud.spanner_dbapi.exceptions import InterfaceError @@ -916,16 +954,14 @@ def test_request_priority(self): sql, params, param_types=param_types, request_options=None ) - @mock.patch("google.cloud.spanner_v1.Client") - def test_custom_client_connection(self, mock_client): + def test_custom_client_connection(self): from google.cloud.spanner_dbapi import connect client = _Client() connection = connect("test-instance", "test-database", client=client) self.assertTrue(connection.instance._client == client) - @mock.patch("google.cloud.spanner_v1.Client") - def test_invalid_custom_client_connection(self, mock_client): + def test_invalid_custom_client_connection(self): from google.cloud.spanner_dbapi import connect client = _Client() @@ -937,6 +973,12 @@ def test_invalid_custom_client_connection(self, mock_client): client=client, ) + def test_connection_wo_database(self): + from google.cloud.spanner_dbapi import connect + + connection = connect("test-instance") + self.assertTrue(connection.database is None) + def exit_ctx_func(self, exc_type, exc_value, traceback): """Context __exit__ method mock.""" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 79ed898355..f744fc769f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -163,6 +163,13 @@ def test_execute_attribute_error(self): with self.assertRaises(AttributeError): cursor.execute(sql="SELECT 1") + def test_execute_database_error(self): + connection = self._make_connection(self.INSTANCE) + cursor = self._make_one(connection) + + with self.assertRaises(ValueError): + cursor.execute(sql="SELECT 1") + def test_execute_autocommit_off(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -607,6 +614,16 @@ def test_executemany_insert_batch_aborted(self): ) self.assertIsInstance(connection._statements[0][1], ResultsChecksum) + @mock.patch("google.cloud.spanner_v1.Client") + def test_executemany_database_error(self, mock_client): + from google.cloud.spanner_dbapi import connect + + connection = connect("test-instance") + cursor = connection.cursor() + + with self.assertRaises(ValueError): + cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) + @unittest.skipIf( sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) @@ -754,6 +771,13 @@ def test_handle_dql_priority(self): sql, None, None, request_options=RequestOptions(priority=1) ) + def test_handle_dql_database_error(self): + connection = self._make_connection(self.INSTANCE) + cursor = self._make_one(connection) + + with self.assertRaises(ValueError): + cursor._handle_DQL("sql", params=None) + def test_context(self): connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection) @@ -814,6 +838,13 @@ def test_run_sql_in_snapshot(self): mock_snapshot.execute_sql.return_value = results self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) + def test_run_sql_in_snapshot_database_error(self): + connection = self._make_connection(self.INSTANCE) + cursor = self._make_one(connection) + + with self.assertRaises(ValueError): + cursor.run_sql_in_snapshot("sql") + def test_get_table_column_schema(self): from google.cloud.spanner_dbapi.cursor import ColumnDetails from google.cloud.spanner_dbapi import _helpers