From 059cfffce4b6d155e83cf137f47439121b1af8f5 Mon Sep 17 00:00:00 2001
From: Subham Sinha
Date: Tue, 16 Sep 2025 22:16:30 +0530
Subject: [PATCH] feat(spanner): add lazy decode to partitioned query

---
 google/cloud/spanner_v1/database.py          |  29 ++++-
 google/cloud/spanner_v1/merged_result_set.py |  41 ++++++-
 tests/unit/test_database.py                  |  15 +++
 tests/unit/test_merged_result_set.py         | 119 +++++++++++++++++++
 4 files changed, 199 insertions(+), 5 deletions(-)
 create mode 100644 tests/unit/test_merged_result_set.py

diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index 215cd5bed8..c5fc56bcc9 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -1532,6 +1532,14 @@ def to_dict(self):
             "transaction_id": snapshot._transaction_id,
         }

+    def __enter__(self):
+        """Begin ``with`` block."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """End ``with`` block."""
+        self.close()
+
     @property
     def observability_options(self):
         return getattr(self._database, "observability_options", {})
@@ -1703,6 +1711,7 @@ def process_read_batch(
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
+        lazy_decode=False,
     ):
         """Process a single, partitioned read.

@@ -1717,6 +1726,14 @@
         :type timeout: float
         :param timeout: (Optional) The timeout for this request.

+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``True``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets. The application is responsible for decoding
+            the data that is needed.
+
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.

@@ -1844,6 +1861,7 @@ def process_query_batch(
         self,
         batch,
         *,
+        lazy_decode: bool = False,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
     ):
@@ -1854,6 +1872,13 @@
             one of the mappings returned from an earlier call
             to :meth:`generate_query_batches`.

+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``True``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets.
+
         :type retry: :class:`~google.api_core.retry.Retry`
         :param retry: (Optional) The retry settings for this request.

@@ -1870,6 +1895,7 @@
         return self._get_snapshot().execute_sql(
             partition=batch["partition"],
             **batch["query"],
+            lazy_decode=lazy_decode,
             retry=retry,
             timeout=timeout,
         )
@@ -1883,6 +1909,7 @@ def run_partitioned_query(
         self,
         max_partitions=None,
         query_options=None,
         data_boost_enabled=False,
+        lazy_decode=False,
     ):
         """Start a partitioned query operation to get a list of partitions,
        and then execute each partition on a separate thread
@@ -1943,7 +1970,7 @@
                    data_boost_enabled,
                )
            )
-        return MergedResultSet(self, partitions, 0)
+        return MergedResultSet(self, partitions, 0, lazy_decode=lazy_decode)

     def process(self, batch):
         """Process a single, partitioned query or read.
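Usage sketch for the surface added above, for context while reading the remaining diffs; the client setup, instance/database names, and query are illustrative placeholders, not part of this change. `run_partitioned_query(..., lazy_decode=True)` fans the query out over partitions on worker threads, the merged iterator yields rows of raw protobuf values, and the caller decodes them on demand; the new `__enter__`/`__exit__` methods additionally let the batch snapshot be used as a context manager (as exercised by `test_context_manager` below):

    from google.cloud import spanner

    client = spanner.Client()
    database = client.instance("my-instance").database("my-database")

    with database.batch_snapshot() as snapshot:
        results = snapshot.run_partitioned_query(
            sql="SELECT SingerId, FirstName FROM Singers",
            lazy_decode=True,
        )
        for row in results:
            # Each row holds raw protobuf values; decode only when needed.
            singer_id, first_name = results.decode_row(row)

Decoding thus happens on the consuming thread, while the partition workers only move undecoded protobufs through the queue, which is where the iteration speedup comes from.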
diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py
index 7af989d696..6c5c792246 100644
--- a/google/cloud/spanner_v1/merged_result_set.py
+++ b/google/cloud/spanner_v1/merged_result_set.py
@@ -33,10 +33,13 @@ class PartitionExecutor:
     rows in the queue
     """

-    def __init__(self, batch_snapshot, partition_id, merged_result_set):
+    def __init__(
+        self, batch_snapshot, partition_id, merged_result_set, lazy_decode=False
+    ):
         self._batch_snapshot: BatchSnapshot = batch_snapshot
         self._partition_id = partition_id
         self._merged_result_set: MergedResultSet = merged_result_set
+        self._lazy_decode = lazy_decode
         self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

     def run(self):
@@ -52,7 +55,9 @@ def __run(self):
         results = None
         try:
-            results = self._batch_snapshot.process_query_batch(self._partition_id)
+            results = self._batch_snapshot.process_query_batch(
+                self._partition_id, lazy_decode=self._lazy_decode
+            )
             for row in results:
                 if self._merged_result_set._metadata is None:
                     self._set_metadata(results)
@@ -75,6 +80,7 @@ def _set_metadata(self, results, is_exception=False):
         try:
             if not is_exception:
                 self._merged_result_set._metadata = results.metadata
+                self._merged_result_set._result_set = results
         finally:
             self._merged_result_set.metadata_lock.release()
             self._merged_result_set.metadata_event.set()
@@ -94,7 +100,10 @@ class MergedResultSet:
     records in the MergedResultSet is not guaranteed.
     """

-    def __init__(self, batch_snapshot, partition_ids, max_parallelism):
+    def __init__(
+        self, batch_snapshot, partition_ids, max_parallelism, lazy_decode=False
+    ):
+        self._result_set = None
         self._exception = None
         self._metadata = None
         self.metadata_event = Event()
@@ -110,7 +119,7 @@ def __init__(self, batch_snapshot, partition_ids, max_parallelism):
         partition_executors = []
         for partition_id in partition_ids:
             partition_executors.append(
-                PartitionExecutor(batch_snapshot, partition_id, self)
+                PartitionExecutor(batch_snapshot, partition_id, self, lazy_decode)
             )
         executor = ThreadPoolExecutor(max_workers=parallelism)
         for partition_executor in partition_executors:
@@ -144,3 +153,27 @@ def stats(self):
         # TODO: Implement
         return None
+
+    def decode_row(self, row: list) -> list:
+        """Decodes a row from protobuf values to Python objects. This function
+        should only be called for result sets that use ``lazy_decode=True``.
+        The array that is returned by this function is the same as the array
+        that would have been returned by the rows iterator if ``lazy_decode=False``.
+
+        :returns: an array containing the decoded values of all the columns in the given row
+        """
+        if self._result_set is None:
+            raise ValueError("iterator not started")
+        return self._result_set.decode_row(row)
+
+    def decode_column(self, row: list, column_index: int):
+        """Decodes a column from a protobuf value to a Python object. This function
+        should only be called for result sets that use ``lazy_decode=True``.
+        The object that is returned by this function is the same as the object
+        that would have been returned by the rows iterator if ``lazy_decode=False``.
+
+        :returns: the decoded column value
+        """
+        if self._result_set is None:
+            raise ValueError("iterator not started")
+        return self._result_set.decode_column(row, column_index)
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py
index 1c7f58c4ab..fa6792b9da 100644
--- a/tests/unit/test_database.py
+++ b/tests/unit/test_database.py
@@ -3141,6 +3141,7 @@ def test_process_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
@@ -3170,6 +3171,7 @@ def test_process_query_batch_w_retry_timeout(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=retry,
             timeout=2.0,
         )
@@ -3193,11 +3195,23 @@ def test_process_query_batch_w_directed_read_options(self):
         snapshot.execute_sql.assert_called_once_with(
             sql=sql,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
             directed_read_options=DIRECTED_READ_OPTIONS,
         )

+    def test_context_manager(self):
+        database = self._make_database()
+        batch_txn = self._make_one(database)
+        session = batch_txn._session = self._make_session()
+        session.is_multiplexed = False
+
+        with batch_txn:
+            pass
+
+        session.delete.assert_called_once_with()
+
     def test_close_wo_session(self):
         database = self._make_database()
         batch_txn = self._make_one(database)
@@ -3292,6 +3306,7 @@ def test_process_w_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
diff --git a/tests/unit/test_merged_result_set.py b/tests/unit/test_merged_result_set.py
new file mode 100644
index 0000000000..99fe50765e
--- /dev/null
+++ b/tests/unit/test_merged_result_set.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import mock
+from google.cloud.spanner_v1.streamed import StreamedResultSet
+
+
+class TestMergedResultSet(unittest.TestCase):
+    def _get_target_class(self):
+        from google.cloud.spanner_v1.merged_result_set import MergedResultSet
+
+        return MergedResultSet
+
+    def _make_one(self, *args, **kwargs):
+        klass = self._get_target_class()
+        # Bypass __init__ so that no partition worker threads are started.
+        obj = super(klass, klass).__new__(klass)
+        from threading import Event, Lock
+
+        obj.metadata_event = Event()
+        obj.metadata_lock = Lock()
+        obj._metadata = None
+        obj._result_set = None
+        return obj
+
+    @staticmethod
+    def _make_value(value):
+        from google.cloud.spanner_v1._helpers import _make_value_pb
+
+        return _make_value_pb(value)
+
+    @staticmethod
+    def _make_scalar_field(name, type_):
+        from google.cloud.spanner_v1 import StructType
+        from google.cloud.spanner_v1 import Type
+
+        return StructType.Field(name=name, type_=Type(code=type_))
+
+    @staticmethod
+    def _make_result_set_metadata(fields=()):
+        from google.cloud.spanner_v1 import ResultSetMetadata
+        from google.cloud.spanner_v1 import StructType
+
+        metadata = ResultSetMetadata(row_type=StructType(fields=[]))
+        for field in fields:
+            metadata.row_type.fields.append(field)
+        return metadata
+
+    def test_stats_property(self):
+        merged = self._make_one()
+        # The property is currently not implemented, so it should just return None.
+        self.assertIsNone(merged.stats)
+
+    def test_decode_row(self):
+        merged = self._make_one()
+
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_row.return_value = ["Phred", 42]
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_row = merged.decode_row(raw_row)
+
+        self.assertEqual(decoded_row, ["Phred", 42])
+        merged._result_set.decode_row.assert_called_once_with(raw_row)
+
+    def test_decode_row_no_result_set(self):
+        merged = self._make_one()
+        merged._result_set = None
+        with self.assertRaisesRegex(ValueError, "iterator not started"):
+            merged.decode_row([])
+
+    def test_decode_row_type_error(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_row.side_effect = TypeError
+
+        with self.assertRaises(TypeError):
+            merged.decode_row("not a list")
+
+    def test_decode_column(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_column.side_effect = ["Phred", 42]
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_name = merged.decode_column(raw_row, 0)
+        decoded_age = merged.decode_column(raw_row, 1)
+
+        self.assertEqual(decoded_name, "Phred")
+        self.assertEqual(decoded_age, 42)
+        merged._result_set.decode_column.assert_has_calls(
+            [mock.call(raw_row, 0), mock.call(raw_row, 1)]
+        )
+
+    def test_decode_column_no_result_set(self):
+        merged = self._make_one()
+        merged._result_set = None
+        with self.assertRaisesRegex(ValueError, "iterator not started"):
+            merged.decode_column([], 0)
+
+    def test_decode_column_type_error(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_column.side_effect = TypeError
+
+        with self.assertRaises(TypeError):
+            merged.decode_column("not a list", 0)
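Usage sketch for `decode_column`, continuing the illustrative `results` iterator from the sketch above: it complements `decode_row` when only some columns of a row are needed, which is where lazy decoding pays off most, since untouched columns are never converted.

    for row in results:
        # Decode only column 1 (FirstName); column 0 stays a raw protobuf value.
        first_name = results.decode_column(row, 1)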