Skip to content

Commit

Permalink
feat(db_api): support stale reads (#584)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ilya Gurov authored Nov 13, 2021
1 parent 63f6572 commit 8ca868c
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 15 deletions.
41 changes: 40 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, instance, database, read_only=False):
# connection close
self._own_pool = True
self._read_only = read_only
self._staleness = None

@property
def autocommit(self):
Expand Down Expand Up @@ -165,6 +166,42 @@ def read_only(self, value):
)
self._read_only = value

@property
def staleness(self):
"""Current read staleness option value of this `Connection`.
Returns:
dict: Staleness type and value.
"""
return self._staleness or {}

@staleness.setter
def staleness(self, value):
"""Read staleness option setter.
Args:
value (dict): Staleness type and value.
"""
if self.inside_transaction:
raise ValueError(
"`staleness` option can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
)

possible_opts = (
"read_timestamp",
"min_read_timestamp",
"max_staleness",
"exact_staleness",
)
if value is not None and sum([opt in value for opt in possible_opts]) != 1:
raise ValueError(
"Expected one of the following staleness options: "
"read_timestamp, min_read_timestamp, max_staleness, exact_staleness."
)

self._staleness = value

def _session_checkout(self):
"""Get a Cloud Spanner session from the pool.
Expand Down Expand Up @@ -284,7 +321,9 @@ def snapshot_checkout(self):
"""
if self.read_only and not self.autocommit:
if not self._snapshot:
self._snapshot = Snapshot(self._session_checkout(), multi_use=True)
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
)
self._snapshot.begin()

return self._snapshot
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ def _handle_DQL(self, sql, params):
)
else:
# execute with single-use snapshot
with self.connection.database.snapshot() as snapshot:
with self.connection.database.snapshot(
**self.connection.staleness
) as snapshot:
self._handle_DQL_with_snapshot(snapshot, sql, params)

def __enter__(self):
Expand Down
34 changes: 33 additions & 1 deletion tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import hashlib
import pickle
import pkg_resources
import pytest

from google.cloud import spanner_v1
from google.cloud.spanner_dbapi.connection import connect, Connection
from google.cloud._helpers import UTC
from google.cloud.spanner_dbapi.connection import connect
from google.cloud.spanner_dbapi.connection import Connection
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_v1 import JsonObject
from . import _helpers
Expand Down Expand Up @@ -429,3 +432,32 @@ def test_read_only(shared_instance, dbapi_database):

cur.execute("SELECT * FROM contacts")
conn.commit()


def test_staleness(shared_instance, dbapi_database):
"""Check the DB API `staleness` option."""
conn = Connection(shared_instance, dbapi_database)
cursor = conn.cursor()

before_insert = datetime.datetime.utcnow().replace(tzinfo=UTC)

cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
"""
)
conn.commit()

conn.read_only = True
conn.staleness = {"read_timestamp": before_insert}
cursor.execute("SELECT * FROM contacts")
conn.commit()
assert len(cursor.fetchall()) == 0

conn.staleness = None
cursor.execute("SELECT * FROM contacts")
conn.commit()
assert len(cursor.fetchall()) == 1

conn.close()
130 changes: 118 additions & 12 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Cloud Spanner DB-API Connection class unit tests."""

import datetime
import mock
import unittest
import warnings
Expand Down Expand Up @@ -688,9 +689,6 @@ def test_retry_transaction_w_empty_response(self):
run_mock.assert_called_with(statement, retried=True)

def test_validate_ok(self):
def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
Expand All @@ -699,7 +697,7 @@ def exit_func(self, exc_type, exc_value, traceback):

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_ctx.__exit__ = exit_ctx_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method
Expand All @@ -710,9 +708,6 @@ def exit_func(self, exc_type, exc_value, traceback):
def test_validate_fail(self):
from google.cloud.spanner_dbapi.exceptions import OperationalError

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
Expand All @@ -721,7 +716,7 @@ def exit_func(self, exc_type, exc_value, traceback):

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_ctx.__exit__ = exit_ctx_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method
Expand All @@ -734,9 +729,6 @@ def exit_func(self, exc_type, exc_value, traceback):
def test_validate_error(self):
from google.cloud.exceptions import NotFound

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
Expand All @@ -745,7 +737,7 @@ def exit_func(self, exc_type, exc_value, traceback):

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_ctx.__exit__ = exit_ctx_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method
Expand All @@ -763,3 +755,117 @@ def test_validate_closed(self):

with self.assertRaises(InterfaceError):
connection.validate()

def test_staleness_invalid_value(self):
"""Check that `staleness` property accepts only correct values."""
connection = self._make_connection()

# incorrect staleness type
with self.assertRaises(ValueError):
connection.staleness = {"something": 4}

# no expected staleness types
with self.assertRaises(ValueError):
connection.staleness = {}

def test_staleness_inside_transaction(self):
"""
Check that it's impossible to change the `staleness`
option if a transaction is in progress.
"""
connection = self._make_connection()
connection._transaction = mock.Mock(committed=False, rolled_back=False)

with self.assertRaises(ValueError):
connection.staleness = {"read_timestamp": datetime.datetime(2021, 9, 21)}

def test_staleness_multi_use(self):
"""
Check that `staleness` option is correctly
sent to the `Snapshot()` constructor.
READ_ONLY, NOT AUTOCOMMIT
"""
timestamp = datetime.datetime(2021, 9, 20)

connection = self._make_connection()
connection._session = "session"
connection.read_only = True
connection.staleness = {"read_timestamp": timestamp}

with mock.patch(
"google.cloud.spanner_dbapi.connection.Snapshot"
) as snapshot_mock:
connection.snapshot_checkout()

snapshot_mock.assert_called_with(
"session", multi_use=True, read_timestamp=timestamp
)

def test_staleness_single_use_autocommit(self):
"""
Check that `staleness` option is correctly
sent to the snapshot context manager.
NOT READ_ONLY, AUTOCOMMIT
"""
timestamp = datetime.datetime(2021, 9, 20)

connection = self._make_connection()
connection._session_checkout = mock.MagicMock(autospec=True)

connection.autocommit = True
connection.staleness = {"read_timestamp": timestamp}

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[1])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_ctx_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

cursor = connection.cursor()
cursor.execute("SELECT 1")

connection.database.snapshot.assert_called_with(read_timestamp=timestamp)

def test_staleness_single_use_readonly_autocommit(self):
"""
Check that `staleness` option is correctly sent to the
snapshot context manager while in `autocommit` mode.
READ_ONLY, AUTOCOMMIT
"""
timestamp = datetime.datetime(2021, 9, 20)

connection = self._make_connection()
connection.autocommit = True
connection.read_only = True
connection._session_checkout = mock.MagicMock(autospec=True)

connection.staleness = {"read_timestamp": timestamp}

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[1])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_ctx_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

cursor = connection.cursor()
cursor.execute("SELECT 1")

connection.database.snapshot.assert_called_with(read_timestamp=timestamp)


def exit_ctx_func(self, exc_type, exc_value, traceback):
"""Context __exit__ method mock."""
pass

0 comments on commit 8ca868c

Please sign in to comment.