Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Additional typing in synapse.util.metrics
Browse files Browse the repository at this point in the history
Didn't get this to pass `no-untyped-def`, think I'll need to watch #10847
  • Loading branch information
David Robertson committed Sep 22, 2021
1 parent f3ad101 commit 0a70837
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions synapse/util/metrics.py
Expand Up @@ -14,7 +14,8 @@

import logging
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from types import TracebackType
from typing import Any, Callable, Optional, Protocol, Type, TypeVar, cast

from prometheus_client import Counter

Expand All @@ -24,6 +25,7 @@
current_context,
)
from synapse.metrics import InFlightGauge
from synapse.util import Clock

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -61,10 +63,15 @@
sub_metrics=["real_time_max", "real_time_sum"],
)

T = TypeVar("T", bound=Callable[..., Any])
R = TypeVar("R")
F = Callable[..., R]


def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
class HasClock(Protocol):
clock: Clock


def measure_func(name: Optional[str] = None) -> Callable[[F], F]:
"""
Used to decorate an async function with a `Measure` context manager.
Expand All @@ -82,16 +89,16 @@ async def foo(...):
"""

def wrapper(func: T) -> T:
def wrapper(func: F) -> F:
block_name = func.__name__ if name is None else name

@wraps(func)
async def measured_func(self, *args, **kwargs):
async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> R:
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r

return cast(T, measured_func)
return cast(F, measured_func)

return wrapper

Expand All @@ -104,10 +111,10 @@ class Measure:
"start",
]

def __init__(self, clock, name: str):
def __init__(self, clock: Clock, name: str) -> None:
"""
Args:
clock: A n object with a "time()" method, which returns the current
clock: An object with a "time()" method, which returns the current
time in seconds.
name: The name of the metric to report.
"""
Expand All @@ -124,7 +131,7 @@ def __init__(self, clock, name: str):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
self.start: Optional[int] = None
self.start: Optional[float] = None

def __enter__(self) -> "Measure":
if self.start is not None:
Expand All @@ -138,7 +145,12 @@ def __enter__(self) -> "Measure":

return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")

Expand Down Expand Up @@ -168,8 +180,9 @@ def get_resource_usage(self) -> ContextResourceUsage:
"""
return self._logging_context.get_resource_usage()

def _update_in_flight(self, metrics):
def _update_in_flight(self, metrics) -> None:
"""Gets called when processing in flight metrics"""
assert self.start is not None
duration = self.clock.time() - self.start

metrics.real_time_max = max(metrics.real_time_max, duration)
Expand Down

0 comments on commit 0a70837

Please sign in to comment.