29 changes: 28 additions & 1 deletion google/cloud/spanner_v1/database.py
@@ -1532,6 +1532,14 @@ def to_dict(self):
"transaction_id": snapshot._transaction_id,
}

def __enter__(self):
"""Begin ``with`` block."""
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
self.close()

@property
def observability_options(self):
return getattr(self._database, "observability_options", {})
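A minimal usage sketch for the new context-manager support, assuming a `Database` instance obtained through the regular client; the query text and variable names are illustrative, not part of this change:

    # `database` is assumed to be a google.cloud.spanner_v1.database.Database.
    with database.batch_snapshot() as snapshot:
        # Partition a query; each batch can be processed independently.
        batches = snapshot.generate_query_batches("SELECT id, name FROM users")
        for batch in batches:
            for row in snapshot.process_query_batch(batch):
                print(row)
    # On exit, __exit__ calls snapshot.close(), releasing the underlying session.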
@@ -1703,6 +1711,7 @@ def process_read_batch(
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
lazy_decode=False,
Contributor: nit: add documentation for this new argument

):
"""Process a single, partitioned read.

@@ -1717,6 +1726,14 @@
:type timeout: float
:param timeout: (Optional) The timeout for this request.

:type lazy_decode: bool
:param lazy_decode:
(Optional) If this argument is set to ``True``, the iterator
returns the underlying protobuf values instead of decoded Python
objects. This reduces the time that is needed to iterate through
large result sets. The application is responsible for decoding
the data that it needs.

:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
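To illustrate the intent of ``lazy_decode``, a hedged sketch: ``snapshot`` is a ``BatchSnapshot`` and ``read_batch`` one of the mappings produced by ``generate_read_batches``; ``decode_column`` is the companion helper on the result set, as exercised by the new unit tests further down.

    # Illustrative only: iterate raw protobuf rows, decoding on demand.
    results = snapshot.process_read_batch(read_batch, lazy_decode=True)
    for raw_row in results:
        # Decode just the first column; the rest stay as protobuf values.
        name = results.decode_column(raw_row, 0)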
@@ -1844,6 +1861,7 @@ def process_query_batch(
self,
batch,
*,
lazy_decode: bool = False,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
@@ -1854,6 +1872,13 @@
one of the mappings returned from an earlier call to
:meth:`generate_query_batches`.

:type lazy_decode: bool
:param lazy_decode:
(Optional) If this argument is set to ``True``, the iterator
returns the underlying protobuf values instead of decoded Python
objects. This reduces the time that is needed to iterate through
large result sets.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

@@ -1870,6 +1895,7 @@
return self._get_snapshot().execute_sql(
partition=batch["partition"],
**batch["query"],
lazy_decode=lazy_decode,
retry=retry,
timeout=timeout,
)
@@ -1883,6 +1909,7 @@ def run_partitioned_query(
max_partitions=None,
query_options=None,
data_boost_enabled=False,
lazy_decode=False,
):
"""Start a partitioned query operation to get list of partitions and
then executes each partition on a separate thread
@@ -1943,7 +1970,7 @@
data_boost_enabled,
)
)
return MergedResultSet(self, partitions, 0)
return MergedResultSet(self, partitions, 0, lazy_decode=lazy_decode)

def process(self, batch):
"""Process a single, partitioned query or read.
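Taken together, a partitioned query with lazy decoding might look like the following sketch; ``batch_snapshot`` and the SQL text are assumptions for illustration:

    # Illustrative only: run a partitioned query and decode rows lazily.
    merged = batch_snapshot.run_partitioned_query(
        "SELECT id, name FROM users",
        lazy_decode=True,
    )
    for raw_row in merged:
        decoded_row = merged.decode_row(raw_row)  # full-row decode on demand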
41 changes: 37 additions & 4 deletions google/cloud/spanner_v1/merged_result_set.py
@@ -33,10 +33,13 @@ class PartitionExecutor:
rows in the queue
"""

def __init__(self, batch_snapshot, partition_id, merged_result_set):
def __init__(
self, batch_snapshot, partition_id, merged_result_set, lazy_decode=False
):
self._batch_snapshot: BatchSnapshot = batch_snapshot
self._partition_id = partition_id
self._merged_result_set: MergedResultSet = merged_result_set
self._lazy_decode = lazy_decode
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
@@ -52,7 +55,9 @@ def run(self):
def __run(self):
results = None
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
results = self._batch_snapshot.process_query_batch(
self._partition_id, lazy_decode=self._lazy_decode
)
for row in results:
if self._merged_result_set._metadata is None:
self._set_metadata(results)
@@ -75,6 +80,7 @@ def _set_metadata(self, results, is_exception=False):
try:
if not is_exception:
self._merged_result_set._metadata = results.metadata
self._merged_result_set._result_set = results
finally:
self._merged_result_set.metadata_lock.release()
self._merged_result_set.metadata_event.set()
@@ -94,7 +100,10 @@ class MergedResultSet:
records in the MergedResultSet is not guaranteed.
"""

def __init__(self, batch_snapshot, partition_ids, max_parallelism):
def __init__(
self, batch_snapshot, partition_ids, max_parallelism, lazy_decode=False
):
self._result_set = None
self._exception = None
self._metadata = None
self.metadata_event = Event()
@@ -110,7 +119,7 @@ def __init__(self, batch_snapshot, partition_ids, max_parallelism):
partition_executors = []
for partition_id in partition_ids:
partition_executors.append(
PartitionExecutor(batch_snapshot, partition_id, self)
PartitionExecutor(batch_snapshot, partition_id, self, lazy_decode)
)
executor = ThreadPoolExecutor(max_workers=parallelism)
for partition_executor in partition_executors:
@@ -144,3 +153,27 @@ def metadata(self):
def stats(self):
# TODO: Implement
return None

def decode_row(self, row: list) -> list:
"""Decodes a row from protobuf values to Python objects. This function
should only be called for result sets that use ``lazy_decode=True``.
The array that is returned by this function is the same as the array
that would have been returned by the rows iterator if ``lazy_decode=False``.

:returns: an array containing the decoded values of all the columns in the given row
"""
if self._result_set is None:
raise ValueError("iterator not started")
return self._result_set.decode_row(row)

def decode_column(self, row: list, column_index: int):
"""Decodes a column from a protobuf value to a Python object. This function
should only be called for result sets that use ``lazy_decode=True``.
The object that is returned by this function is the same as the object
that would have been returned by the rows iterator if ``lazy_decode=False``.

:returns: the decoded column value
"""
if self._result_set is None:
raise ValueError("iterator not started")
return self._result_set.decode_column(row, column_index)
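As a usage note for the two helpers above, a hedged sketch of selective decoding; ``merged`` is assumed to be a ``MergedResultSet`` produced with ``lazy_decode=True``, and ``ids_of_interest`` is hypothetical:

    # Illustrative only: decode one cheap column first, the full row only when needed.
    for raw_row in merged:
        user_id = merged.decode_column(raw_row, 0)
        if user_id in ids_of_interest:
            full_row = merged.decode_row(raw_row)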
15 changes: 15 additions & 0 deletions tests/unit/test_database.py
@@ -3141,6 +3141,7 @@ def test_process_query_batch(self):
params=params,
param_types=param_types,
partition=token,
lazy_decode=False,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)
@@ -3170,6 +3171,7 @@ def test_process_query_batch_w_retry_timeout(self):
params=params,
param_types=param_types,
partition=token,
lazy_decode=False,
retry=retry,
timeout=2.0,
)
@@ -3193,11 +3195,23 @@ def test_process_query_batch_w_directed_read_options(self):
snapshot.execute_sql.assert_called_once_with(
sql=sql,
partition=token,
lazy_decode=False,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
directed_read_options=DIRECTED_READ_OPTIONS,
)

def test_context_manager(self):
database = self._make_database()
batch_txn = self._make_one(database)
session = batch_txn._session = self._make_session()
session.is_multiplexed = False

with batch_txn:
pass

session.delete.assert_called_once_with()

def test_close_wo_session(self):
database = self._make_database()
batch_txn = self._make_one(database)
@@ -3292,6 +3306,7 @@ def test_process_w_query_batch(self):
params=params,
param_types=param_types,
partition=token,
lazy_decode=False,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)
119 changes: 119 additions & 0 deletions tests/unit/test_merged_result_set.py
@@ -0,0 +1,119 @@
# Copyright 2025 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import mock
from google.cloud.spanner_v1.streamed import StreamedResultSet


class TestMergedResultSet(unittest.TestCase):
def _get_target_class(self):
from google.cloud.spanner_v1.merged_result_set import MergedResultSet

return MergedResultSet

def _make_one(self, *args, **kwargs):
klass = self._get_target_class()
obj = super(klass, klass).__new__(klass)
from threading import Event, Lock

obj.metadata_event = Event()
obj.metadata_lock = Lock()
obj._metadata = None
obj._result_set = None
return obj

@staticmethod
def _make_value(value):
from google.cloud.spanner_v1._helpers import _make_value_pb

return _make_value_pb(value)

@staticmethod
def _make_scalar_field(name, type_):
from google.cloud.spanner_v1 import StructType
from google.cloud.spanner_v1 import Type

return StructType.Field(name=name, type_=Type(code=type_))

@staticmethod
def _make_result_set_metadata(fields=()):
from google.cloud.spanner_v1 import ResultSetMetadata
from google.cloud.spanner_v1 import StructType

metadata = ResultSetMetadata(row_type=StructType(fields=[]))
for field in fields:
metadata.row_type.fields.append(field)
return metadata

def test_stats_property(self):
merged = self._make_one()
# The property is currently not implemented, so it should just return None.
self.assertIsNone(merged.stats)

def test_decode_row(self):
merged = self._make_one()

merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
merged._result_set.decode_row.return_value = ["Phred", 42]

raw_row = [self._make_value("Phred"), self._make_value(42)]
decoded_row = merged.decode_row(raw_row)

self.assertEqual(decoded_row, ["Phred", 42])
merged._result_set.decode_row.assert_called_once_with(raw_row)

def test_decode_row_no_result_set(self):
merged = self._make_one()
merged._result_set = None
with self.assertRaisesRegex(ValueError, "iterator not started"):
merged.decode_row([])

def test_decode_row_type_error(self):
merged = self._make_one()
merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
merged._result_set.decode_row.side_effect = TypeError

with self.assertRaises(TypeError):
merged.decode_row("not a list")

def test_decode_column(self):
merged = self._make_one()
merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
merged._result_set.decode_column.side_effect = ["Phred", 42]

raw_row = [self._make_value("Phred"), self._make_value(42)]
decoded_name = merged.decode_column(raw_row, 0)
decoded_age = merged.decode_column(raw_row, 1)

self.assertEqual(decoded_name, "Phred")
self.assertEqual(decoded_age, 42)
merged._result_set.decode_column.assert_has_calls(
[mock.call(raw_row, 0), mock.call(raw_row, 1)]
)

def test_decode_column_no_result_set(self):
merged = self._make_one()
merged._result_set = None
with self.assertRaisesRegex(ValueError, "iterator not started"):
merged.decode_column([], 0)

def test_decode_column_type_error(self):
merged = self._make_one()
merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
merged._result_set.decode_column.side_effect = TypeError

with self.assertRaises(TypeError):
merged.decode_column("not a list", 0)