Skip to content
Permalink
Browse files
fix: correctly set resume token when restarting streams (#314)
* fix: correctly set resume token for restarting streams

* style: fix lint

* docs: update docstring

* test: fix assertion

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
larkee and larkee committed Apr 26, 2021
1 parent 772aa3c commit 0fcfc2301246d3f20b6fbffc1deae06f16721ec7
Showing with 76 additions and 48 deletions.
  1. +3 −3 google/cloud/spanner_v1/database.py
  2. +19 −7 google/cloud/spanner_v1/snapshot.py
  3. +54 −38 tests/unit/test_snapshot.py
@@ -518,11 +518,11 @@ def execute_pdml():
param_types=param_types,
query_options=query_options,
)
restart = functools.partial(
api.execute_streaming_sql, request=request, metadata=metadata,
method = functools.partial(
api.execute_streaming_sql, metadata=metadata,
)

iterator = _restart_on_unavailable(restart)
iterator = _restart_on_unavailable(method, request)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
@@ -41,16 +41,21 @@
)


def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None):
def _restart_on_unavailable(
method, request, trace_name=None, session=None, attributes=None
):
"""Restart iteration after :exc:`.ServiceUnavailable`.
:type restart: callable
:param restart: curried function returning iterator
:type method: callable
:param method: function returning iterator
:type request: proto
:param request: request proto to call the method with
"""
resume_token = b""
item_buffer = []
with trace_call(trace_name, session, attributes):
iterator = restart()
iterator = method(request=request)
while True:
try:
for item in iterator:
@@ -61,7 +66,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
except ServiceUnavailable:
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue
except InternalServerError as exc:
resumable_error = any(
@@ -72,7 +78,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
raise
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue

if len(item_buffer) == 0:
@@ -189,7 +196,11 @@ def read(

trace_attributes = {"table_id": table, "columns": columns}
iterator = _restart_on_unavailable(
restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
)

self._read_request_count += 1
@@ -302,6 +313,7 @@ def execute_sql(
trace_attributes = {"db.statement": sql}
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
@@ -47,10 +47,12 @@


class Test_restart_on_unavailable(OpenTelemetryBase):
def _call_fut(self, restart, span_name=None, session=None, attributes=None):
def _call_fut(
self, restart, request, span_name=None, session=None, attributes=None
):
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable

return _restart_on_unavailable(restart, span_name, session, attributes)
return _restart_on_unavailable(restart, request, span_name, session, attributes)

def _make_item(self, value, resume_token=b""):
return mock.Mock(
@@ -59,18 +61,21 @@ def _make_item(self, value, resume_token=b""):

def test_iteration_w_empty_raw(self):
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), [])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_non_empty_raw(self):
ITEMS = (self._make_item(0), self._make_item(1))
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_w_resume_tken(self):
@@ -81,10 +86,11 @@ def test_iteration_w_raw_w_resume_tken(self):
self._make_item(3),
)
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_no_token(self):
@@ -97,10 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self):
)
before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
@@ -118,10 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
),
)
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
@@ -134,11 +144,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
)
before = _MockIterator(fail_after=True, error=InternalServerError("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(spec=["resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable(self):
@@ -151,12 +162,12 @@ def test_iteration_w_raw_raising_unavailable(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error(self):
@@ -173,12 +184,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self):
)
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error(self):
@@ -191,11 +202,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self):
*(FIRST + SECOND), fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_after_token(self):
@@ -207,12 +219,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
*FIRST, fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
@@ -228,12 +240,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
)
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
@@ -245,19 +257,23 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
*FIRST, fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_span_creation(self):
name = "TestSpan"
extra_atts = {"test_att": 1}
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts)
resumable = self._call_fut(
restart, request, name, _Session(_Database()), extra_atts
)
self.assertEqual(list(resumable), [])
self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1))

@@ -272,13 +288,13 @@ def test_iteration_w_multiple_span_creation(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
name = "TestSpan"
resumable = self._call_fut(restart, name, _Session(_Database()))
resumable = self._call_fut(restart, request, name, _Session(_Database()))
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)

span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 2)

0 comments on commit 0fcfc23

Please sign in to comment.