Skip to content

Commit

Permalink
type annotate all functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cenkalti committed Mar 7, 2018
1 parent ef86f7e commit d2d4534
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 229 deletions.
133 changes: 1 addition & 132 deletions kuyruk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import sys
import json
import logging
import pkg_resources
from contextlib import contextmanager
from typing import Dict, Any, Set # noqa

import amqp

from kuyruk.kuyruk import Kuyruk
from kuyruk.task import Task
from kuyruk.config import Config
from kuyruk.worker import Worker
from kuyruk import signals

__all__ = ['Kuyruk', 'Config', 'Task', 'Worker']

Expand All @@ -21,128 +15,3 @@
# Add NullHandler to prevent logging warnings on startup
null_handler = logging.NullHandler()
logger.addHandler(null_handler)


class Kuyruk:
"""
Provides :func:`~kuyruk.Kuyruk.task` decorator to convert a function
into a :class:`~kuyruk.Task`.
Provides :func:`~kuyruk.Kuyruk.channel` context manager for opening a
new channel on the connection.
Connection is opened when the first channel is created.
:param config: Must be an instance of :class:`~kuyruk.Config`.
If ``None``, default config is used.
See :class:`~kuyruk.Config` for default values.
"""
def __init__(self, config=None):
if config is None:
config = Config()
if not isinstance(config, Config):
raise TypeError
self.config = config
self.extensions = {} # type: Dict[str, Any]

def task(self, queue='kuyruk', **kwargs):
"""
Wrap functions with this decorator to convert them to *tasks*.
After wrapping, calling the function will send a message to
a queue instead of running the function.
:param queue: Queue name for the tasks.
:param kwargs: Keyword arguments will be passed to
:class:`~kuyruk.Task` constructor.
:return: Callable :class:`~kuyruk.Task` object wrapping the original
function.
"""
def decorator():
def inner(f):
# Function may be wrapped with no-arg decorator
queue_ = 'kuyruk' if callable(queue) else queue

return Task(f, self, queue_, **kwargs)

return inner

return decorator()

@contextmanager
def channel(self):
"""Returns a new channel from a new connection as a context manager.
"""
with self.connection() as conn:
ch = conn.channel()
logger.info('Opened new channel')
with _safe_close(ch):
yield ch

@contextmanager
def connection(self):
"""Returns a new connection as a context manager."""
conn = amqp.Connection(
host="%s:%s" % (self.config.RABBIT_HOST, self.config.RABBIT_PORT),
userid=self.config.RABBIT_USER,
password=self.config.RABBIT_PASSWORD,
virtual_host=self.config.RABBIT_VIRTUAL_HOST,
connect_timeout=self.config.RABBIT_CONNECT_TIMEOUT,
read_timeout=self.config.RABBIT_READ_TIMEOUT,
write_timeout=self.config.RABBIT_WRITE_TIMEOUT,
)
conn.connect()
logger.info('Connected to RabbitMQ')
with _safe_close(conn):
yield conn

def send_tasks_to_queue(self, subtasks):
if self.config.EAGER:
for subtask in subtasks:
subtask.task.apply(*subtask.args, **subtask.kwargs)
return

declared_queues = set() # type: Set[str]
with self.channel() as ch:
for subtask in subtasks:
queue = subtask.task._queue_for_host(subtask.host)
if queue not in declared_queues:
ch.queue_declare(queue=queue, durable=True, auto_delete=False)
declared_queues.add(queue)

description = subtask.task._get_description(subtask.args,
subtask.kwargs)
subtask.task._send_signal(signals.task_presend,
args=subtask.args,
kwargs=subtask.kwargs,
description=description)

body = json.dumps(description)
msg = amqp.Message(body=body)
ch.basic_publish(msg, exchange="", routing_key=queue)
subtask.task._send_signal(signals.task_postsend,
args=subtask.args,
kwargs=subtask.kwargs,
description=description)


@contextmanager
def _safe_close(obj):
try:
yield
except Exception:
# Error occurred in block. Save exception info for re-raising later.
exc_info = sys.exc_info()

# We still need to close the object but not interested with errors,
# because we will raise the original exception above.
try:
obj.close()
except Exception:
pass

# After closing the object, we are re-raising the saved exception.
raise exc_info[1].with_traceback(exc_info[2])
else:
# No error occurred in block. We must close the object as usual.
obj.close()
4 changes: 2 additions & 2 deletions kuyruk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)


def main():
def main() -> None:
parser = argparse.ArgumentParser(conflict_handler='resolve')

# Add common options
Expand Down Expand Up @@ -73,7 +73,7 @@ def main():
args.func(app, args)


def run_worker(app, args):
def run_worker(app: Kuyruk, args: argparse.Namespace) -> None:
w = Worker(app, args)
w.run()

Expand Down
14 changes: 7 additions & 7 deletions kuyruk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import types
import logging
import pkg_resources
from typing import Dict, Any # noqa
from typing import Dict, Any, Union # noqa

from kuyruk import importer

Expand Down Expand Up @@ -57,7 +57,7 @@ class attributes. Additional attributes may be added by extensions.
WORKER_LOGGING_LEVEL = 'INFO'
"""Logging level of root logger."""

def from_object(self, obj):
def from_object(self, obj: Union[str, Any]) -> None:
"""Load values from an object."""
if isinstance(obj, str):
obj = importer.import_object_str(obj)
Expand All @@ -67,14 +67,14 @@ def from_object(self, obj):
self._setattr(key, value)
logger.info("Config is loaded from object: %r", obj)

def from_dict(self, d):
def from_dict(self, d: Dict[str, Any]) -> None:
"""Load values from a dict."""
for key, value in d.items():
if key.isupper():
self._setattr(key, value)
logger.info("Config is loaded from dict: %r", d)

def from_pymodule(self, name):
def from_pymodule(self, name: str) -> None:
if not isinstance(name, str):
raise TypeError
module = importer.import_module(name)
Expand All @@ -83,7 +83,7 @@ def from_pymodule(self, name):
self._setattr(key, value)
logger.info("Config is loaded from module: %s", name)

def from_pyfile(self, filename):
def from_pyfile(self, filename: str) -> None:
"""Load values from a Python file."""
globals_ = {} # type: Dict[str, Any]
locals_ = {} # type: Dict[str, Any]
Expand All @@ -96,7 +96,7 @@ def from_pyfile(self, filename):

logger.info("Config is loaded from file: %s", filename)

def from_env_vars(self):
def from_env_vars(self) -> None:
"""Load values from environment variables.
Keys must start with `KUYRUK_`."""
for key, value in os.environ.items():
Expand All @@ -109,7 +109,7 @@ def from_env_vars(self):
pass
self._setattr(key, value)

def _setattr(self, key, value):
def _setattr(self, key: str, value: Any) -> None:
if not hasattr(self.__class__, key):
raise ValueError("Unknown config key: %s" % key)
setattr(self, key, value)
Expand Down
12 changes: 9 additions & 3 deletions kuyruk/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from typing import Tuple, Type
from types import TracebackType

ExcInfoType = Tuple[Type[BaseException], BaseException, TracebackType]


class KuyrukError(Exception):
"""Base class for Kuyruk exceptions."""
pass
Expand Down Expand Up @@ -41,12 +47,12 @@ class RemoteException(KuyrukError):
exception is raised on the worker while running the task.
"""
def __init__(self, type_, value, traceback):
def __init__(self, type_: Type, value: Exception, traceback: TracebackType) -> None:
self.type = type_
self.value = value
self.traceback = traceback

def __str__(self):
def __str__(self) -> str:
return "%s(%r)" % (self.type, self.value)


Expand All @@ -55,5 +61,5 @@ class HeartbeatError(KuyrukError):
Raised when there is problem while sending heartbeat during task execution.
"""
def __init__(self, exc_info):
def __init__(self, exc_info: ExcInfoType) -> None:
self.exc_info = exc_info
11 changes: 7 additions & 4 deletions kuyruk/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
import socket
import logging
import threading
from typing import Callable, Tuple, Any

import amqp

logger = logging.getLogger(__name__)


class Heartbeat:

def __init__(self, connection, on_error):
def __init__(self, connection: amqp.Connection, on_error: Callable[[Tuple[Any, Any, Any]], None]) -> None:
self._connection = connection
self._on_error = on_error
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run)

def start(self):
def start(self) -> None:
self._thread.start()

def stop(self):
def stop(self) -> None:
self._stop.set()
self._thread.join()

def _run(self):
def _run(self) -> None:
while not self._stop.is_set():
try:
self._connection.heartbeat_tick()
Expand Down
30 changes: 15 additions & 15 deletions kuyruk/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,47 @@
import sys
import logging
import importlib
from collections import namedtuple
from typing import Any
from types import ModuleType

logger = logging.getLogger(__name__)

main_module = sys.modules['__main__']


def import_module(name):
def import_module(name: str) -> ModuleType:
"""Import module by it's name from following places in order:
- main module
- current working directory
- Python path
"""
logger.debug("Importing module: %s", name)
module, main_module_name = get_main_module()
if name == main_module_name:
return module
if name == main_module_name():
return main_module

return importlib.import_module(name)


def import_object(module_name, object_name):
def import_object(module_name: str, object_name: str) -> Any:
module = import_module(module_name)
try:
return getattr(module, object_name)
except AttributeError as e:
raise ImportError(e)


def import_object_str(s):
def import_object_str(s: str) -> Any:
module, obj = s.rsplit('.', 1)
return import_object(module, obj)


def get_main_module():
def main_module_name() -> str:
"""Returns main module and module name pair."""
if not hasattr(main_module, '__file__'):
# if run from interactive shell
return None, None
# running from interactive shell
return None

main_filename = os.path.basename(main_module.__file__)
module_name, ext = os.path.splitext(main_filename)
return FakeModule(module=main_module, name=module_name)


FakeModule = namedtuple('MainModule', ['module', 'name'])
main_module = sys.modules['__main__']
return module_name
Loading

0 comments on commit d2d4534

Please sign in to comment.