Skip to content

Commit

Permalink
Fully type test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
jodal committed Jun 7, 2023
1 parent 1be1e04 commit 3d92e9c
Show file tree
Hide file tree
Showing 21 changed files with 994 additions and 473 deletions.
6 changes: 1 addition & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ pytest-mock = "^3.10.0"
target-version = ["py38", "py39", "py310", "py311"]

[tool.mypy]
disallow_untyped_defs = true
no_implicit_optional = true
strict_equality = true
warn_return_any = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unused_configs = true

[[tool.mypy.overrides]]
# Options that only applies to src, and not tests
module = "pykka.*"
disallow_untyped_defs = true

[tool.ruff]
select = [
"A", # flake8-builtins
Expand Down
3 changes: 3 additions & 0 deletions src/pykka/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self) -> None:
self._get_hook = None
self._get_hook_result = None

def __repr__(self) -> str:
return "<pykka.Future>"

def get(
self,
timeout: Optional[float] = None,
Expand Down
73 changes: 39 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations

import logging
import threading
import time
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Iterator,
cast,
)

import pytest

from pykka import ActorRegistry, ThreadingActor, ThreadingFuture
from tests.log_handler import PykkaTestLogHandler
from tests.types import Events, Runtime

Runtime = namedtuple(
"Runtime",
["name", "actor_class", "event_class", "future_class", "sleep_func"],
)
if TYPE_CHECKING:
from pykka import Actor, Future


RUNTIMES = {
Expand All @@ -29,18 +35,18 @@


@pytest.fixture(scope="session", params=RUNTIMES.values())
def runtime(request):
return request.param
def runtime(request: pytest.FixtureRequest) -> Runtime:
return cast(Runtime, request.param)


@pytest.fixture()
def _stop_all():
def _stop_all() -> Iterator[None]: # pyright: ignore[reportUnusedFunction]
yield
ActorRegistry.stop_all()


@pytest.fixture()
def log_handler():
def log_handler() -> Iterator[logging.Handler]:
log_handler = PykkaTestLogHandler()

root_logger = logging.getLogger()
Expand All @@ -55,25 +61,24 @@ def log_handler():


@pytest.fixture()
def events(runtime):
class Events:
on_start_was_called = runtime.event_class()
on_stop_was_called = runtime.event_class()
on_failure_was_called = runtime.event_class()
greetings_was_received = runtime.event_class()
actor_registered_before_on_start_was_called = runtime.event_class()

return Events()
def events(runtime: Runtime) -> Events:
return Events(
on_start_was_called=runtime.event_class(),
on_stop_was_called=runtime.event_class(),
on_failure_was_called=runtime.event_class(),
greetings_was_received=runtime.event_class(),
actor_registered_before_on_start_was_called=runtime.event_class(),
)


@pytest.fixture(scope="module")
def early_failing_actor_class(runtime):
class EarlyFailingActor(runtime.actor_class):
def __init__(self, events):
def early_failing_actor_class(runtime: Runtime) -> type[Actor]:
class EarlyFailingActor(runtime.actor_class): # type: ignore[name-defined] # noqa: E501
def __init__(self, events: Events) -> None:
super().__init__()
self.events = events

def on_start(self):
def on_start(self) -> None:
try:
raise RuntimeError("on_start failure")
finally:
Expand All @@ -83,16 +88,16 @@ def on_start(self):


@pytest.fixture(scope="module")
def late_failing_actor_class(runtime):
class LateFailingActor(runtime.actor_class):
def __init__(self, events):
def late_failing_actor_class(runtime: Runtime) -> type[Actor]:
class LateFailingActor(runtime.actor_class): # type: ignore[name-defined] # noqa: E501
def __init__(self, events: Events) -> None:
super().__init__()
self.events = events

def on_start(self):
def on_start(self) -> None:
self.stop()

def on_stop(self):
def on_stop(self) -> None:
try:
raise RuntimeError("on_stop failure")
finally:
Expand All @@ -102,18 +107,18 @@ def on_stop(self):


@pytest.fixture(scope="module")
def failing_on_failure_actor_class(runtime):
class FailingOnFailureActor(runtime.actor_class):
def __init__(self, events):
def failing_on_failure_actor_class(runtime: Runtime) -> type[Actor]:
class FailingOnFailureActor(runtime.actor_class): # type: ignore[name-defined] # noqa: E501
def __init__(self, events: Events) -> None:
super().__init__()
self.events = events

def on_receive(self, message):
def on_receive(self, message: Any) -> Any:
if message.get("command") == "raise exception":
raise Exception("on_receive failure")
return super().on_receive(message)

def on_failure(self, *args):
def on_failure(self, *args: Any) -> None:
try:
raise RuntimeError("on_failure failure")
finally:
Expand All @@ -123,10 +128,10 @@ def on_failure(self, *args):


@pytest.fixture()
def future(runtime):
def future(runtime: Runtime) -> Future[Any]:
return runtime.future_class()


@pytest.fixture()
def futures(runtime):
def futures(runtime: Runtime) -> list[Future[Any]]:
return [runtime.future_class() for _ in range(3)]
32 changes: 25 additions & 7 deletions tests/log_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,48 @@
import logging
import threading
import time
from enum import Enum
from typing import Any, Dict, List


class LogLevel(str, Enum):
DEBUG = "debug"
INFO = "info"
WARNING = "warning"
ERROR = "error"
CRITICAL = "critical"


class PykkaTestLogHandler(logging.Handler):
def __init__(self, *args, **kwargs):
self.lock = threading.RLock()
lock: threading.RLock # type: ignore[assignment]
events: Dict[str, threading.Event]
messages: Dict[LogLevel, List[logging.LogRecord]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.lock = ( # pyright: ignore[reportIncompatibleVariableOverride]
threading.RLock()
)
with self.lock:
self.events = collections.defaultdict(threading.Event)
self.messages = {}
self.reset()
logging.Handler.__init__(self, *args, **kwargs)

def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
with self.lock:
level = record.levelname.lower()
level = LogLevel(record.levelname.lower())
self.messages[level].append(record)
self.events[level].set()

def reset(self):
def reset(self) -> None:
with self.lock:
for level in ("debug", "info", "warning", "error", "critical"):
for level in LogLevel:
self.events[level].clear()
self.messages[level] = []

def wait_for_message(self, level, num_messages=1, timeout=5):
def wait_for_message(
self, level: LogLevel, num_messages: int = 1, timeout: float = 5
) -> None:
"""Wait until at least ``num_messages`` log messages have been emitted
to the given log level."""
deadline = time.time() + timeout
Expand Down
19 changes: 11 additions & 8 deletions tests/performance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# ruff: noqa: T201

from __future__ import annotations

import time
from typing import Any, Callable

from pykka import ActorRegistry, ThreadingActor


def time_it(func):
def time_it(func: Callable[[], Any]) -> None:
start = time.time()
func()
elapsed = time.time() - start
Expand All @@ -16,7 +19,7 @@ class SomeObject:
pykka_traversable = False
cat = "bar.cat"

def func(self):
def func(self) -> None:
pass


Expand All @@ -26,33 +29,33 @@ class AnActor(ThreadingActor):

foo = "foo"

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.cat = "quox"

def func(self):
def func(self) -> None:
pass


def test_direct_plain_attribute_access():
def test_direct_plain_attribute_access() -> None:
actor = AnActor.start().proxy()
for _ in range(10000):
actor.foo.get()


def test_direct_callable_attribute_access():
def test_direct_callable_attribute_access() -> None:
actor = AnActor.start().proxy()
for _ in range(10000):
actor.func().get()


def test_traversable_plain_attribute_access():
def test_traversable_plain_attribute_access() -> None:
actor = AnActor.start().proxy()
for _ in range(10000):
actor.bar.cat.get()


def test_traversable_callable_attribute_access():
def test_traversable_callable_attribute_access() -> None:
actor = AnActor.start().proxy()
for _ in range(10000):
actor.bar.func().get()
Expand Down

0 comments on commit 3d92e9c

Please sign in to comment.