Skip to content

Commit

Permalink
feat: Batch Write API implementation and samples (#1027)
Browse files Browse the repository at this point in the history
* feat: Batch Write API implementation and samples

* Update sample

* review comments

* return public class for mutation groups

* Update google/cloud/spanner_v1/batch.py

Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com>

* Update google/cloud/spanner_v1/batch.py

Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com>

* review comments

* remove doc

* feat(spanner): nit sample data refactoring

* review comments

* fix test

---------

Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com>
Co-authored-by: Sri Harsha CH <sriharshach@google.com>
  • Loading branch information
3 people committed Dec 3, 2023
1 parent 7debe71 commit aa36b07
Show file tree
Hide file tree
Showing 9 changed files with 584 additions and 0 deletions.
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from .types.result_set import ResultSetStats
from .types.spanner import BatchCreateSessionsRequest
from .types.spanner import BatchCreateSessionsResponse
from .types.spanner import BatchWriteRequest
from .types.spanner import BatchWriteResponse
from .types.spanner import BeginTransactionRequest
from .types.spanner import CommitRequest
from .types.spanner import CreateSessionRequest
Expand Down Expand Up @@ -99,6 +101,8 @@
# google.cloud.spanner_v1.types
"BatchCreateSessionsRequest",
"BatchCreateSessionsResponse",
"BatchWriteRequest",
"BatchWriteResponse",
"BeginTransactionRequest",
"CommitRequest",
"CommitResponse",
Expand Down
94 changes: 94 additions & 0 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1 import BatchWriteRequest

from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
Expand Down Expand Up @@ -215,6 +216,99 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.commit()


class MutationGroup(_BatchBase):
"""A container for mutations.
Clients should use :class:`~google.cloud.spanner_v1.MutationGroups` to
obtain instances instead of directly creating instances.
:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: The session used to perform the commit.
:type mutations: list
:param mutations: The list into which mutations are to be accumulated.
"""

def __init__(self, session, mutations=[]):
super(MutationGroup, self).__init__(session)
self._mutations = mutations


class MutationGroups(_SessionWrapper):
"""Accumulate mutation groups for transmission during :meth:`batch_write`.
:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit
"""

committed = None

def __init__(self, session):
super(MutationGroups, self).__init__(session)
self._mutation_groups = []

def _check_state(self):
"""Checks if the object's state is valid for making API requests.
:raises: :exc:`ValueError` if the object's state is invalid for making
API requests.
"""
if self.committed is not None:
raise ValueError("MutationGroups already committed")

def group(self):
"""Returns a new `MutationGroup` to which mutations can be added."""
mutation_group = BatchWriteRequest.MutationGroup()
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None):
"""Executes batch_write.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
:returns: a sequence of responses for each batch.
"""
self._check_state()

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
trace_attributes = {"num_mutation_groups": len(self._mutation_groups)}
if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

request = BatchWriteRequest(
session=self._session.name,
mutation_groups=self._mutation_groups,
request_options=request_options,
)
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = True
return response


def _make_write_pb(table, columns, values):
"""Helper for :meth:`Batch.insert` et al.
Expand Down
45 changes: 45 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
Expand Down Expand Up @@ -734,6 +735,17 @@ def batch(self, request_options=None):
"""
return BatchCheckout(self, request_options)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
The wrapper *must* be used as a context manager, with the mutation group
as the value returned by the wrapper.
:rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout`
:returns: new wrapper
"""
return MutationGroupsCheckout(self)

def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
"""Return an object which wraps a batch read / query.
Expand Down Expand Up @@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._database._pool.put(self._session)


class MutationGroupsCheckout(object):
"""Context manager for using mutation groups from a database.
Inside the context manager, checks out a session from the database,
creates mutation groups from it, making the groups available.
Caller must *not* use the object to perform API requests outside the scope
of the context manager.
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
"""

def __init__(self, database):
self._database = database
self._session = None

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
return MutationGroups(session)

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
self._database._pool.put(self._session)


class SnapshotCheckout(object):
"""Context manager for using a snapshot from a database.
Expand Down
62 changes: 62 additions & 0 deletions samples/samples/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,65 @@ def insert_data(instance_id, database_id):
# [END spanner_insert_data]


# [START spanner_batch_write_at_least_once]
def batch_write(instance_id, database_id):
"""Inserts sample data into the given database via BatchWrite API.
The database and table must already exist and can be created using
`create_database`.
"""
from google.rpc.code_pb2 import OK

spanner_client = spanner.Client()
instance = spanner_client.instance(instance_id)
database = instance.database(database_id)

with database.mutation_groups() as groups:
group1 = groups.group()
group1.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(16, "Scarlet", "Terry"),
],
)

group2 = groups.group()
group2.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(17, "Marc", ""),
(18, "Catalina", "Smith"),
],
)
group2.insert_or_update(
table="Albums",
columns=("SingerId", "AlbumId", "AlbumTitle"),
values=[
(17, 1, "Total Junk"),
(18, 2, "Go, Go, Go"),
],
)

for response in groups.batch_write():
if response.status.code == OK:
print(
"Mutation group indexes {} have been applied with commit timestamp {}".format(
response.indexes, response.commit_timestamp
)
)
else:
print(
"Mutation group indexes {} could not be applied with error {}".format(
response.indexes, response.status
)
)


# [END spanner_batch_write_at_least_once]


# [START spanner_delete_data]
def delete_data(instance_id, database_id):
"""Deletes sample data from the given database.
Expand Down Expand Up @@ -2677,6 +2736,7 @@ def drop_sequence(instance_id, database_id):
subparsers.add_parser("create_instance", help=create_instance.__doc__)
subparsers.add_parser("create_database", help=create_database.__doc__)
subparsers.add_parser("insert_data", help=insert_data.__doc__)
subparsers.add_parser("batch_write", help=batch_write.__doc__)
subparsers.add_parser("delete_data", help=delete_data.__doc__)
subparsers.add_parser("query_data", help=query_data.__doc__)
subparsers.add_parser("read_data", help=read_data.__doc__)
Expand Down Expand Up @@ -2811,6 +2871,8 @@ def drop_sequence(instance_id, database_id):
create_database(args.instance_id, args.database_id)
elif args.command == "insert_data":
insert_data(args.instance_id, args.database_id)
elif args.command == "batch_write":
batch_write(args.instance_id, args.database_id)
elif args.command == "delete_data":
delete_data(args.instance_id, args.database_id)
elif args.command == "query_data":
Expand Down
7 changes: 7 additions & 0 deletions samples/samples/snippets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database):
assert "Inserted data" in out


@pytest.mark.dependency(name="batch_write")
def test_batch_write(capsys, instance_id, sample_database):
snippets.batch_write(instance_id, sample_database.database_id)
out, _ = capsys.readouterr()
assert "could not be applied with error" not in out


@pytest.mark.dependency(depends=["insert_data"])
def test_delete_data(capsys, instance_id, sample_database):
snippets.delete_data(instance_id, sample_database.database_id)
Expand Down
8 changes: 8 additions & 0 deletions tests/system/_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
(2, "Bharney", "Rhubble", "bharney@example.com"),
(3, "Wylma", "Phlyntstone", "wylma@example.com"),
)
BATCH_WRITE_ROW_DATA = (
(1, "Phred", "Phlyntstone", "phred@example.com"),
(2, "Bharney", "Rhubble", "bharney@example.com"),
(3, "Wylma", "Phlyntstone", "wylma@example.com"),
(4, "Pebbles", "Phlyntstone", "pebbles@example.com"),
(5, "Betty", "Rhubble", "betty@example.com"),
(6, "Slate", "Stephenson", "slate@example.com"),
)
ALL = spanner_v1.KeySet(all_=True)
SQL = "SELECT * FROM contacts ORDER BY contact_id"

Expand Down
35 changes: 35 additions & 0 deletions tests/system/test_session_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,6 +2521,41 @@ def test_partition_query(sessions_database, not_emulator):
batch_txn.close()


def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_database):
sd = _sample_data
num_groups = 3
num_mutations_per_group = len(sd.BATCH_WRITE_ROW_DATA) // num_groups

with sessions_database.batch() as batch:
batch.delete(sd.TABLE, sd.ALL)

with sessions_database.mutation_groups() as groups:
for i in range(num_groups):
group = groups.group()
for j in range(num_mutations_per_group):
group.insert_or_update(
sd.TABLE,
sd.COLUMNS,
[sd.BATCH_WRITE_ROW_DATA[i * num_mutations_per_group + j]],
)
# Response indexes received
seen = collections.Counter()
for response in groups.batch_write():
_check_batch_status(response.status.code)
assert response.commit_timestamp is not None
assert len(response.indexes) > 0
seen.update(response.indexes)
# All indexes must be in the range [0, num_groups-1] and seen exactly once
assert len(seen) == num_groups
assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items())

# Verify the writes by reading from the database
with sessions_database.snapshot() as snapshot:
rows = list(snapshot.execute_sql(sd.SQL))

sd._check_rows_data(rows, sd.BATCH_WRITE_ROW_DATA)


class FauxCall:
def __init__(self, code, details="FauxCall"):
self._code = code
Expand Down
Loading

0 comments on commit aa36b07

Please sign in to comment.