Skip to content

Commit

Permalink
feat: Leader Aware Routing (#899)
Browse files Browse the repository at this point in the history
* changes

* tests

* Update client.py

* Update test_client.py

* Update connection.py

* setting feature false

* changes
  • Loading branch information
asthamohta committed Apr 27, 2023
1 parent 10a1351 commit f9fefad
Show file tree
Hide file tree
Showing 21 changed files with 543 additions and 93 deletions.
19 changes: 17 additions & 2 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def connect(
pool=None,
user_agent=None,
client=None,
route_to_leader_enabled=False,
):
"""Creates a connection to a Google Cloud Spanner database.
Expand Down Expand Up @@ -544,6 +545,14 @@ def connect(
:class:`~google.cloud.spanner_v1.Client`.
:param client: (Optional) Custom user provided Client Object
:type route_to_leader_enabled: boolean
:param route_to_leader_enabled:
(Optional) Default False. Set route_to_leader_enabled as True to
Enable leader aware routing. Enabling leader aware routing
would route all requests in RW/PDML transactions to the
leader region.
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
Expand All @@ -556,11 +565,17 @@ def connect(
)
if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
credentials,
project=project,
client_info=client_info,
route_to_leader_enabled=False,
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
project=project,
credentials=credentials,
client_info=client_info,
route_to_leader_enabled=False,
)
else:
if project is not None and client.project != project:
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,15 @@ def _metadata_with_prefix(prefix, **kw):
List[Tuple[str, str]]: RPC metadata with supplied prefix
"""
return [("google-cloud-resource-prefix", prefix)]


def _metadata_with_leader_aware_routing(value, **kw):
"""Create RPC metadata containing a leader aware routing header
Args:
value (bool): header value
Returns:
List[Tuple[str, str]]: RPC metadata with leader aware routing header
"""
return ("x-goog-spanner-route-to-leader", str(value).lower())
9 changes: 8 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions

Expand Down Expand Up @@ -159,6 +162,10 @@ def commit(self, return_commit_stats=False, request_options=None):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}

Expand Down
19 changes: 19 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ class Client(ClientWithProject):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`
:type route_to_leader_enabled: boolean
:param route_to_leader_enabled:
(Optional) Default False. Set route_to_leader_enabled as True to
Enable leader aware routing. Enabling leader aware routing
would route all requests in RW/PDML transactions to the
leader region.
:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`
"""
Expand All @@ -132,6 +139,7 @@ def __init__(
client_info=_CLIENT_INFO,
client_options=None,
query_options=None,
route_to_leader_enabled=False,
):
self._emulator_host = _get_spanner_emulator_host()

Expand Down Expand Up @@ -171,6 +179,8 @@ def __init__(
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)

self._route_to_leader_enabled = route_to_leader_enabled

@property
def credentials(self):
"""Getter for client's credentials.
Expand Down Expand Up @@ -242,6 +252,15 @@ def database_admin_api(self):
)
return self._database_admin_api

@property
def route_to_leader_enabled(self):
"""Getter for if read-write or pdml requests will be routed to leader.
:rtype: boolean
:returns: If read-write requests will be routed to leader.
"""
return self._route_to_leader_enabled

def copy(self):
"""Make a copy of this client.
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
Expand Down Expand Up @@ -155,6 +158,7 @@ def __init__(
self._encryption_config = encryption_config
self._database_dialect = database_dialect
self._database_role = database_role
self._route_to_leader_enabled = self._instance._client.route_to_leader_enabled

if pool is None:
pool = BurstyPool(database_role=database_role)
Expand Down Expand Up @@ -565,6 +569,10 @@ def execute_partitioned_dml(
)

metadata = _metadata_with_prefix(self.name)
if self._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
)

def execute_pdml():
with SessionCheckout(self._pool) as session:
Expand Down
13 changes: 12 additions & 1 deletion google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from google.cloud.exceptions import NotFound
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
from google.cloud.spanner_v1 import Session
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from warnings import warn

_NOW = datetime.datetime.utcnow # unit tests may replace
Expand Down Expand Up @@ -191,6 +194,10 @@ def bind(self, database):
self._database = database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
self._database_role = self._database_role or self._database.database_role
request = BatchCreateSessionsRequest(
database=database.name,
Expand Down Expand Up @@ -402,6 +409,10 @@ def bind(self, database):
self._database = database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
created_session_count = 0
self._database_role = self._database_role or self._database.database_role

Expand Down
17 changes: 16 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.snapshot import Snapshot
Expand Down Expand Up @@ -125,6 +128,12 @@ def create(self):
raise ValueError("Session ID already set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
self._database._route_to_leader_enabled
)
)

request = CreateSessionRequest(database=self._database.name)
if self._database.database_role is not None:
Expand Down Expand Up @@ -153,6 +162,12 @@ def exists(self):
return False
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
self._database._route_to_leader_enabled
)
)

with trace_call("CloudSpanner.GetSession", self) as span:
try:
Expand Down
29 changes: 26 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
Expand Down Expand Up @@ -235,6 +238,10 @@ def read(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

if request_options is None:
request_options = RequestOptions()
Expand All @@ -244,7 +251,7 @@ def read(
if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
elif self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

request = ReadRequest(
Expand Down Expand Up @@ -391,6 +398,10 @@ def execute_sql(

database = self._session._database
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

api = database.spanner_api

Expand All @@ -406,7 +417,7 @@ def execute_sql(
if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
elif self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

request = ExecuteSqlRequest(
Expand Down Expand Up @@ -527,6 +538,10 @@ def partition_read(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
Expand Down Expand Up @@ -621,6 +636,10 @@ def partition_query(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
Expand Down Expand Up @@ -766,6 +785,10 @@ def begin(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
(_metadata_with_leader_aware_routing(database._route_to_leader_enabled))
)
txn_selector = self._make_txn_selector()
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
Expand Down
24 changes: 24 additions & 0 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_make_value_pb,
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
Expand Down Expand Up @@ -50,6 +51,7 @@ class Transaction(_SnapshotBase, _BatchBase):
_multi_use = True
_execute_sql_count = 0
_lock = threading.Lock()
_read_only = False

def __init__(self, session):
if session._transaction is not None:
Expand Down Expand Up @@ -124,6 +126,10 @@ def begin(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
Expand All @@ -140,6 +146,12 @@ def rollback(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
database._route_to_leader_enabled
)
)
with trace_call("CloudSpanner.Rollback", self._session):
api.rollback(
session=self._session.name,
Expand Down Expand Up @@ -176,6 +188,10 @@ def commit(self, return_commit_stats=False, request_options=None):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
trace_attributes = {"num_mutations": len(self._mutations)}

if request_options is None:
Expand Down Expand Up @@ -294,6 +310,10 @@ def execute_update(
params_pb = self._make_params_pb(params, param_types)
database = self._session._database
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
api = database.spanner_api

seqno, self._execute_sql_count = (
Expand Down Expand Up @@ -406,6 +426,10 @@ def batch_update(self, statements, request_options=None):

database = self._session._database
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
api = database.spanner_api

seqno, self._execute_sql_count = (
Expand Down
Loading

0 comments on commit f9fefad

Please sign in to comment.