Skip to content

Commit

Permalink
PYTHON-3186 Avoid SDAM heartbeat timeouts on AWS Lambda (#912)
Browse files Browse the repository at this point in the history
Poll monitor socket with timeout=0 one last time after timeout expires.
This avoids heartbeat timeouts and connection churn on Lambda and other FaaS envs.
  • Loading branch information
ShaneHarvey committed Mar 30, 2022
1 parent 1d30802 commit c58950a
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 5 deletions.
11 changes: 9 additions & 2 deletions pymongo/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def wait_for_read(sock_info, deadline):
# Only Monitor connections can be cancelled.
if context:
sock = sock_info.sock
timed_out = False
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
Expand All @@ -252,15 +253,21 @@ def wait_for_read(sock_info, deadline):
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001)
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
if context.cancelled:
raise _OperationCancelled("hello cancelled")
if readable:
return
if deadline and time.monotonic() > deadline:
if timed_out:
raise socket.timeout("timed out")


Expand Down
23 changes: 20 additions & 3 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from test.version import Version
from typing import Dict, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus

import pymongo
import pymongo.errors
Expand Down Expand Up @@ -279,6 +280,22 @@ def client_options(self):
opts["replicaSet"] = self.replica_set_name
return opts

@property
def uri(self):
    """Return a ``mongodb://`` URI suitable for creating a duplicate client.

    The query string is built from a copy of
    ``client_context.default_client_options``; credentials are included
    only when ``client_context.auth_enabled`` is true.
    """

    def _render(opt, val):
        # Booleans are written lowercase; every value is URL-quoted.
        strval = str(val).lower() if isinstance(val, bool) else str(val)
        return f"{opt}={quote_plus(strval)}"

    default_opts = client_context.default_client_options.copy()
    opts_part = "&".join(_render(opt, val) for opt, val in default_opts.items())
    if client_context.auth_enabled:
        auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
    else:
        auth_part = ""
    return f"mongodb://{auth_part}{self.pair}/?{opts_part}"

@property
def hello(self):
if not self._hello:
Expand Down Expand Up @@ -359,7 +376,7 @@ def _init_client(self):
username=db_user,
password=db_pwd,
replicaSet=self.replica_set_name,
**self.default_client_options
**self.default_client_options,
)

# May not have this if OperationFailure was raised earlier.
Expand Down Expand Up @@ -387,7 +404,7 @@ def _init_client(self):
username=db_user,
password=db_pwd,
replicaSet=self.replica_set_name,
**self.default_client_options
**self.default_client_options,
)
else:
self.client = pymongo.MongoClient(
Expand Down Expand Up @@ -490,7 +507,7 @@ def _check_user_provided(self):
username=db_user,
password=db_pwd,
serverSelectionTimeoutMS=100,
**self.default_client_options
**self.default_client_options,
)

try:
Expand Down
85 changes: 85 additions & 0 deletions test/sigstop_sigcont.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Used by test_client.TestClient.test_sigstop_sigcont."""

import logging
import sys

sys.path[0:0] = [""]

from pymongo import monitoring
from pymongo.mongo_client import MongoClient


class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    """Heartbeat listener that logs every event until :meth:`close` is called."""

    def __init__(self):
        # Set to True by close(); afterwards every callback is a no-op.
        self.closed = False

    def close(self):
        """Stop logging any further heartbeat events."""
        self.closed = True

    def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
        """Log a heartbeat-started event at INFO level unless closed."""
        if not self.closed:
            logging.info("%s", event)

    def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
        """Log a heartbeat-succeeded event at INFO level unless closed."""
        if not self.closed:
            logging.info("%s", event)

    def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
        """Log a heartbeat-failed event at WARNING level unless closed."""
        if not self.closed:
            logging.warning("%s", event)


def main(uri: str) -> None:
    """Connect to ``uri``, log heartbeats, and idle until told to quit.

    Pings the server, logs "TEST STARTED", then blocks reading stdin until
    a line equal to "q" (or EOF) arrives. Afterwards it pings once more,
    logs "TEST COMPLETED", silences the heartbeat listener and closes the
    client.
    """
    listener = HeartbeatLogger()
    options = {
        "event_listeners": [listener],
        "heartbeatFrequencyMS": 500,
        "connectTimeoutMS": 500,
    }
    mongo = MongoClient(uri, **options)
    mongo.admin.command("ping")
    logging.info("TEST STARTED")
    # test_sigstop_sigcont will SIGSTOP and SIGCONT this process while it
    # sits in this loop waiting for input.
    waiting = True
    while waiting:
        try:
            waiting = input('Type "q" to quit: ') != "q"
        except EOFError:
            waiting = False
    mongo.admin.command("ping")
    logging.info("TEST COMPLETED")
    listener.close()
    mongo.close()


if __name__ == "__main__":
    # Exactly one positional argument is expected: the MongoDB URI.
    if len(sys.argv) != 2:
        print("unknown or missing options")
        print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'")
        # Use sys.exit rather than the bare exit() builtin: exit() is only
        # installed by the site module and is missing under ``python -S``.
        sys.exit(1)

    # Enable logs in this format:
    # 2022-03-30 12:40:55,582 INFO <ServerHeartbeatStartedEvent ('localhost', 27017)>
    FORMAT = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=FORMAT, level=logging.INFO)
    main(sys.argv[1])
34 changes: 34 additions & 0 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import signal
import socket
import struct
import subprocess
import sys
import threading
import time
Expand Down Expand Up @@ -1688,6 +1689,39 @@ def test_srv_max_hosts_kwarg(self):
)
self.assertEqual(len(client.topology_description.server_descriptions()), 2)

@unittest.skipIf(
    client_context.load_balancer or client_context.serverless,
    "loadBalanced clients do not run SDAM",
)
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
def test_sigstop_sigcont(self):
    """Suspend and resume a child client; no heartbeat may fail.

    Runs ``sigstop_sigcont.py`` in a subprocess, pauses it with SIGSTOP,
    resumes it with SIGCONT, asks it to quit, and then inspects its
    combined stdout/stderr log output.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    command = [sys.executable, os.path.join(here, "sigstop_sigcont.py"), client_context.uri]
    child = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # Cleanups run in reverse registration order: kill first, then wait.
    self.addCleanup(child.wait, timeout=1)
    self.addCleanup(child.kill)
    time.sleep(1)
    # Stop the child, sleep for twice the streaming timeout
    # (heartbeatFrequencyMS + connectTimeoutMS), and restart.
    os.kill(child.pid, signal.SIGSTOP)
    time.sleep(2)
    os.kill(child.pid, signal.SIGCONT)
    time.sleep(0.5)
    # Tell the script to exit gracefully.
    captured, _ = child.communicate(input=b"q\n", timeout=10)
    self.assertTrue(captured)
    log_output = captured.decode("utf-8")
    expected_markers = (
        "TEST STARTED",
        "ServerHeartbeatStartedEvent",
        "ServerHeartbeatSucceededEvent",
        "TEST COMPLETED",
    )
    for marker in expected_markers:
        self.assertIn(marker, log_output)
    self.assertNotIn("ServerHeartbeatFailedEvent", log_output)


class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
Expand Down

0 comments on commit c58950a

Please sign in to comment.