feat: named database support #439

Merged
merged 10 commits on Jun 21, 2023
6 changes: 3 additions & 3 deletions google/cloud/datastore/__init__.py
@@ -34,9 +34,9 @@
The main concepts with this API are:

- :class:`~google.cloud.datastore.client.Client`
which represents a project (string) and namespace (string) bundled with
a connection and has convenience methods for constructing objects with that
project / namespace.
which represents a project (string), database (string), and namespace
(string) bundled with a connection and has convenience methods for
constructing objects with that project/database/namespace.

- :class:`~google.cloud.datastore.entity.Entity`
which represents a single entity in the datastore
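For orientation, here is a minimal sketch of what the feature looks like at the client level, assuming the `database` keyword argument this PR introduces; the project and database ids below are placeholders, not real resources.

```python
# Hedged sketch: selecting a named database on the client (ids are placeholders).
from google.cloud import datastore

client = datastore.Client(project="my-project", database="my-database")

# Keys and entities built through this client inherit its project/database.
key = client.key("Task", 1234)
entity = datastore.Entity(key=key)
entity["done"] = False
client.put(entity)
```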
38 changes: 37 additions & 1 deletion google/cloud/datastore/_http.py
@@ -59,6 +59,7 @@ def _request(
data,
base_url,
client_info,
database=None,
retry=None,
timeout=None,
):
@@ -84,6 +85,9 @@
:type client_info: :class:`google.api_core.client_info.ClientInfo`
:param client_info: used to generate user agent.

:type database: str
:param database: (Optional) The database to make the request for.

:type retry: :class:`google.api_core.retry.Retry`
:param retry: (Optional) retry policy for the request

@@ -101,6 +105,7 @@
"User-Agent": user_agent,
connection_module.CLIENT_INFO_HEADER: user_agent,
}
_update_headers(headers, project, database)
api_url = build_api_url(project, method, base_url)

requester = http.request
@@ -136,6 +141,7 @@ def _rpc(
client_info,
request_pb,
response_pb_cls,
database=None,
retry=None,
timeout=None,
):
@@ -165,6 +171,9 @@
:param response_pb_cls: The class used to unmarshall the response
protobuf.

:type database: str
:param database: (Optional) The database to make the request for.

:type retry: :class:`google.api_core.retry.Retry`
:param retry: (Optional) retry policy for the request

@@ -177,7 +186,7 @@
req_data = request_pb._pb.SerializeToString()
kwargs = _make_retry_timeout_kwargs(retry, timeout)
response = _request(
http, project, method, req_data, base_url, client_info, **kwargs
http, project, method, req_data, base_url, client_info, database, **kwargs
)
return response_pb_cls.deserialize(response)

@@ -236,6 +245,7 @@ def lookup(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.LookupRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -245,6 +255,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.LookupResponse,
database_id,
retry=retry,
timeout=timeout,
)
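As a hedged illustration of the plumbing above, the snippet below calls the low-level lookup with a `database_id` in the request dict; `HTTPDatastoreAPI` is assumed to be the class these methods belong to in `_http.py`, and the ids are placeholders.

```python
# Illustrative only: database_id set on the request is read back from the
# request protobuf and handed to _rpc, which forwards it for header routing.
api = HTTPDatastoreAPI(client)  # client as constructed in the earlier sketch
response = api.lookup(
    request={
        "project_id": "my-project",      # placeholder ids
        "database_id": "my-database",
        "keys": [client.key("Task", 1234).to_protobuf()],
    }
)
```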
@@ -267,6 +278,7 @@
"""
request_pb = _make_request_pb(request, _datastore_pb2.RunQueryRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -276,6 +288,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.RunQueryResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -300,6 +313,7 @@ def run_aggregation_query(self, request, retry=None, timeout=None):
request, _datastore_pb2.RunAggregationQueryRequest
)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -309,6 +323,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.RunAggregationQueryResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -331,6 +346,7 @@ def begin_transaction(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.BeginTransactionRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -340,6 +356,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.BeginTransactionResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -362,6 +379,7 @@ def commit(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.CommitRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -371,6 +389,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.CommitResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -393,6 +412,7 @@ def rollback(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.RollbackRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -402,6 +422,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.RollbackResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -424,6 +445,7 @@ def allocate_ids(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.AllocateIdsRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -433,6 +455,7 @@
self.client._client_info,
request_pb,
_datastore_pb2.AllocateIdsResponse,
database_id,
retry=retry,
timeout=timeout,
)
@@ -455,6 +478,7 @@ def reserve_ids(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.ReserveIdsRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
@@ -464,6 +488,18 @@
self.client._client_info,
request_pb,
_datastore_pb2.ReserveIdsResponse,
database_id,
retry=retry,
timeout=timeout,
)


def _update_headers(headers, project_id, database_id=None):
"""Update the request headers.
Pass the project id, or optionally the database_id if provided.
"""
headers["x-goog-request-params"] = f"project_id={project_id}"
if database_id:
headers[
"x-goog-request-params"
] = f"project_id={project_id}&database_id={database_id}"
16 changes: 16 additions & 0 deletions google/cloud/datastore/aggregation.py
@@ -22,6 +22,7 @@
from google.cloud.datastore_v1.types import query as query_pb2
from google.cloud.datastore import helpers
from google.cloud.datastore.query import _pb_from_query
from google.cloud.datastore.constants import DEFAULT_DATABASE


_NOT_FINISHED = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED
@@ -123,6 +124,18 @@ def project(self):
"""
return self._nested_query._project or self._client.project

@property
def database(self):
"""Get the database for this AggregationQuery.
:rtype: str
:returns: The database for the query.
"""
if self._nested_query._database or (
self._nested_query._database == DEFAULT_DATABASE
):
return self._nested_query._database
return self._client.database

@property
def namespace(self):
"""The nested query's namespace
@@ -376,6 +389,7 @@ def _next_page(self):

partition_id = entity_pb2.PartitionId(
project_id=self._aggregation_query.project,
database_id=self._aggregation_query.database,
namespace_id=self._aggregation_query.namespace,
)

@@ -390,6 +404,7 @@
response_pb = self.client._datastore_api.run_aggregation_query(
request={
"project_id": self._aggregation_query.project,
"database_id": self._aggregation_query.database,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
@@ -409,6 +424,7 @@ def _next_page(self):
response_pb = self.client._datastore_api.run_aggregation_query(
request={
"project_id": self._aggregation_query.project,
"database_id": self._aggregation_query.database,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
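Tying the aggregation changes together, a hedged end-to-end sketch of a count aggregation against a named database; the ids and kind are placeholders, and the aggregation helpers (`client.aggregation_query`, `count`) are assumed to match the public API of the same era as this PR.

```python
from google.cloud import datastore

client = datastore.Client(project="my-project", database="my-database")
query = client.query(kind="Task")
agg_query = client.aggregation_query(query)
agg_query.count(alias="total")

# AggregationQuery.database resolves to "my-database" here, so _next_page()
# stamps database_id on the PartitionId and on the run_aggregation_query request.
for result_batch in agg_query.fetch():
    for aggregation in result_batch:
        print(aggregation.alias, aggregation.value)
```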
23 changes: 23 additions & 0 deletions google/cloud/datastore/batch.py
@@ -23,6 +23,7 @@

from google.cloud.datastore import helpers
from google.cloud.datastore_v1.types import datastore as _datastore_pb2
from google.cloud.datastore.constants import DEFAULT_DATABASE


class Batch(object):
@@ -122,6 +123,15 @@ def project(self):
"""
return self._client.project

@property
def database(self):
"""Getter for database in which the batch will run.

:rtype: :class:`str`
:returns: The database in which the batch will run.
"""
return self._client.database

@property
def namespace(self):
"""Getter for namespace in which the batch will run.
@@ -218,6 +228,12 @@ def put(self, entity):
if self.project != entity.key.project:
raise ValueError("Key must be from same project as batch")

entity_key_database = entity.key.database
if entity_key_database is None:
entity_key_database = DEFAULT_DATABASE
if self.database != entity_key_database:
raise ValueError("Key must be from same database as batch")

if entity.key.is_partial:
entity_pb = self._add_partial_key_entity_pb()
self._partial_key_entities.append(entity)
@@ -245,6 +261,12 @@ def delete(self, key):
if self.project != key.project:
raise ValueError("Key must be from same project as batch")

key_db = key.database
if key_db is None:
key_db = DEFAULT_DATABASE
if self.database != key_db:
raise ValueError("Key must be from same database as batch")

key_pb = key.to_protobuf()
self._add_delete_key_pb()._pb.CopyFrom(key_pb._pb)

@@ -284,6 +306,7 @@ def _commit(self, retry, timeout):
commit_response_pb = self._client._datastore_api.commit(
request={
"project_id": self.project,
"database_id": self.database,
"mode": mode,
"transaction": self._id,
"mutations": self._mutations,
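Finally, a hedged sketch of the new cross-database guard in `Batch.put` and `Batch.delete`; the database ids are placeholders, and keys are built through clients so that they carry a database.

```python
from google.cloud import datastore

client_a = datastore.Client(project="my-project", database="db-a")
client_b = datastore.Client(project="my-project", database="db-b")

batch = client_a.batch()
batch.begin()

try:
    batch.put(datastore.Entity(key=client_b.key("Task", 1)))
except ValueError as exc:
    print(exc)  # "Key must be from same database as batch"

batch.put(datastore.Entity(key=client_a.key("Task", 2)))  # same database: accepted
batch.rollback()  # discard the sketch's mutations rather than committing
```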