diff --git a/kubernetes/base/watch/watch.py b/kubernetes/base/watch/watch.py index e8fe6c63e6..e703ca87ab 100644 --- a/kubernetes/base/watch/watch.py +++ b/kubernetes/base/watch/watch.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time import json import pydoc import sys @@ -180,10 +181,22 @@ def stream(self, func, *args, **kwargs): disable_retries = ('timeout_seconds' in kwargs) retry_after_410 = False deserialize = kwargs.pop('deserialize', True) + + health_check_interval = kwargs.pop('_health_check_interval', 0) # 0 = disabled by default + last_event_time = time.time() if health_check_interval > 0 else None + while True: resp = func(*args, **kwargs) try: for line in iter_resp_lines(resp): + # Health check for silent connection drops + if health_check_interval > 0 and last_event_time is not None: + current_time = time.time() + if current_time - last_event_time > health_check_interval: + # Silent connection detected - break to reconnect + break + last_event_time = current_time + # unmarshal when we are receiving events from watch, # return raw string when we are streaming log if watch_arg == "watch": diff --git a/kubernetes/base/watch/watch_test.py b/kubernetes/base/watch/watch_test.py index 4907dd5433..04c542e09b 100644 --- a/kubernetes/base/watch/watch_test.py +++ b/kubernetes/base/watch/watch_test.py @@ -576,6 +576,36 @@ def test_pod_log_empty_lines(self): self.api.delete_namespaced_pod(name=pod_name, namespace=self.namespace) self.api.delete_namespaced_pod.assert_called_once_with(name=pod_name, namespace=self.namespace) + def test_health_check_detects_silent_connection_drop(self): + """Test that health check detects when connection stops receiving events""" + fake_resp = Mock() + fake_resp.close = Mock() + fake_resp.release_conn = Mock() + + def limited_stalled_stream(): + yield '{"type": "ADDED", "object": {"metadata": {"name": "test1", "resourceVersion": "1"}}}\n' + for _ in range(10): + yield '' + return + + fake_resp.stream = Mock(return_value=limited_stalled_stream()) + + fake_api = Mock() + fake_api.get_namespaces = Mock(return_value=fake_resp) + fake_api.get_namespaces.__doc__ = ':return: V1NamespaceList' + + w = Watch() + events = [] + + try: + for e in w.stream(fake_api.get_namespaces, _health_check_interval=0.1, timeout_seconds=1): + events.append(e) + except Exception: + pass + + self.assertEqual(1, len(events)) + self.assertEqual("test1", events[0]['object'].metadata.name) + # Comment out the test below, it does not work currently. # def test_watch_with_deserialize_param(self): # """test watch.stream() deserialize param"""