Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 49 additions & 35 deletions pymongo/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

"""Class to monitor a MongoDB server on a background thread."""

from __future__ import annotations

import atexit
import time
import weakref
from typing import Any, Mapping, cast
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Tuple, cast

from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
Expand All @@ -29,24 +31,29 @@
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver

if TYPE_CHECKING:
from pymongo.pool import Connection, Pool
from pymongo.settings import TopologySettings
from pymongo.topology import Topology


def _sanitize(error):
def _sanitize(error: Exception) -> None:
"""PYTHON-2433 Clear error traceback info."""
error.__traceback__ = None
error.__context__ = None
error.__cause__ = None


class MonitorBase:
def __init__(self, topology, name, interval, min_interval):
def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
"""Base class to do periodic work on a background thread.

The background thread is signaled to stop when the Topology or
this instance is freed.
"""
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
def target():
def target() -> bool:
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
Expand All @@ -59,7 +66,7 @@ def target():

self._executor = executor

def _on_topology_gc(dummy=None):
def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
# This prevents GC from waiting 10 seconds for hello to complete
# See test_cleanup_executors_on_client_del.
monitor = self_ref()
Expand All @@ -71,35 +78,41 @@ def _on_topology_gc(dummy=None):
self._topology = weakref.proxy(topology, _on_topology_gc)
_register(self)

def open(self):
def open(self) -> None:
"""Start monitoring, or restart after a fork.

Multiple calls have no effect.
"""
self._executor.open()

def gc_safe_close(self):
def gc_safe_close(self) -> None:
"""GC safe close."""
self._executor.close()

def close(self):
def close(self) -> None:
"""Close and stop monitoring.

open() restarts the monitor after closing.
"""
self.gc_safe_close()

def join(self, timeout=None):
def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)

def request_check(self):
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
self._executor.wake()


class Monitor(MonitorBase):
def __init__(self, server_description, topology, pool, topology_settings):
def __init__(
self,
server_description: ServerDescription,
topology: Topology,
pool: Pool,
topology_settings: TopologySettings,
):
"""Class to monitor a MongoDB server on a background thread.

Pass an initial ServerDescription, a Topology, a Pool, and
Expand Down Expand Up @@ -128,7 +141,7 @@ def __init__(self, server_description, topology, pool, topology_settings):
)
self.heartbeater = None

def cancel_check(self):
def cancel_check(self) -> None:
"""Cancel any concurrent hello check.

Note: this is called from a weakref.proxy callback and MUST NOT take
Expand All @@ -141,7 +154,7 @@ def cancel_check(self):
# (depending on the platform).
context.cancel()

def _start_rtt_monitor(self):
def _start_rtt_monitor(self) -> None:
"""Start an _RttMonitor that periodically runs ping."""
# If this monitor is closed directly before (or during) this open()
# call, the _RttMonitor will not be closed. Checking if this monitor
Expand All @@ -150,23 +163,23 @@ def _start_rtt_monitor(self):
if self._executor._stopped:
self._rtt_monitor.close()

def gc_safe_close(self):
def gc_safe_close(self) -> None:
self._executor.close()
self._rtt_monitor.gc_safe_close()
self.cancel_check()

def close(self):
def close(self) -> None:
self.gc_safe_close()
self._rtt_monitor.close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._reset_connection()

def _reset_connection(self):
def _reset_connection(self) -> None:
# Clear our pooled connection.
self._pool.reset()

def _run(self):
def _run(self) -> None:
try:
prev_sd = self._server_description
try:
Expand Down Expand Up @@ -203,7 +216,7 @@ def _run(self):
# Topology was garbage-collected.
self.close()

def _check_server(self):
def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.

Returns a ServerDescription.
Expand Down Expand Up @@ -234,7 +247,7 @@ def _check_server(self):
# Server type defaults to Unknown.
return ServerDescription(address, error=error)

def _check_once(self):
def _check_once(self) -> ServerDescription:
"""A single attempt to call hello.

Returns a ServerDescription, or raises an exception.
Expand All @@ -259,7 +272,7 @@ def _check_once(self):
)
return sd

def _check_with_socket(self, conn):
def _check_with_socket(self, conn: Connection) -> Tuple[Hello, float]:
"""Return (Hello, round_trip_time).

Can raise ConnectionFailure or OperationFailure.
Expand All @@ -283,7 +296,7 @@ def _check_with_socket(self, conn):


class SrvMonitor(MonitorBase):
def __init__(self, topology, topology_settings):
def __init__(self, topology: Topology, topology_settings: TopologySettings):
"""Class to poll SRV records on a background thread.

Pass a Topology and a TopologySettings.
Expand All @@ -298,9 +311,10 @@ def __init__(self, topology, topology_settings):
)
self._settings = topology_settings
self._seedlist = self._settings._seeds
self._fqdn = self._settings.fqdn
assert isinstance(self._settings.fqdn, str)
self._fqdn: str = self._settings.fqdn

def _run(self):
def _run(self) -> None:
seedlist = self._get_seedlist()
if seedlist:
self._seedlist = seedlist
Expand All @@ -310,7 +324,7 @@ def _run(self):
# Topology was garbage-collected.
self.close()

def _get_seedlist(self):
def _get_seedlist(self) -> Optional[List[Tuple[str, Any]]]:
"""Poll SRV records for a seedlist.

Returns a list of ServerDescriptions.
Expand Down Expand Up @@ -338,7 +352,7 @@ def _get_seedlist(self):


class _RttMonitor(MonitorBase):
def __init__(self, topology, topology_settings, pool):
def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
"""Maintain round trip times for a server.

The Topology is weakly referenced.
Expand All @@ -355,30 +369,30 @@ def __init__(self, topology, topology_settings, pool):
self._moving_min = MovingMinimum()
self._lock = _create_lock()

def close(self):
def close(self) -> None:
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._pool.reset()

def add_sample(self, sample):
def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)

def get(self):
def get(self) -> Tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
return self._moving_average.get(), self._moving_min.get()

def reset(self):
def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
self._moving_average.reset()
self._moving_min.reset()

def _run(self):
def _run(self) -> None:
try:
# NOTE: This thread is only run when using the streaming
# heartbeat protocol (MongoDB 4.4+).
Expand All @@ -391,7 +405,7 @@ def _run(self):
except Exception:
self._pool.reset()

def _ping(self):
def _ping(self) -> float:
"""Run a "hello" command and return the RTT."""
with self._pool.checkout() as conn:
if self._executor._stopped:
Expand All @@ -407,16 +421,16 @@ def _ping(self):
_MONITORS = set()


def _register(monitor):
def _register(monitor: MonitorBase) -> None:
ref = weakref.ref(monitor, _unregister)
_MONITORS.add(ref)


def _unregister(monitor_ref):
def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
_MONITORS.remove(monitor_ref)


def _shutdown_monitors():
def _shutdown_monitors() -> None:
if _MONITORS is None:
return

Expand All @@ -432,7 +446,7 @@ def _shutdown_monitors():
monitor = None


def _shutdown_resources():
def _shutdown_resources() -> None:
# _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
shutdown = _shutdown_monitors
if shutdown: # type:ignore[truthy-function]
Expand Down
2 changes: 1 addition & 1 deletion test/pymongo_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, client, server_description, topology, pool, topology_settings
def _check_once(self):
client = self.client
address = self._server_description.address
response, rtt = client.mock_hello("%s:%d" % address)
response, rtt = client.mock_hello("%s:%d" % address) # type: ignore[str-format]
return ServerDescription(address, Hello(response), rtt)


Expand Down