Skip to content

Commit

Permalink
docs: continued docstring for drivers (#1951)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianmtr committed Feb 16, 2021
1 parent 2143fa2 commit adba5aa
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 79 deletions.
104 changes: 80 additions & 24 deletions jina/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class QuerySetReader:

@property
def as_querylang(self):
"""Render as QueryLang parameters.
# noqa: DAR201"""
parameters = {
name: getattr(self, name) for name in self._init_kwargs_dict.keys()
}
Expand All @@ -119,10 +121,10 @@ def _get_parameter(self, key: str, default: Any):
if getattr(self, 'queryset', None):
for q in self.queryset:
if (
not q.disabled
and self.__class__.__name__ == q.name
and q.priority > self._priority
and key in q.parameters
not q.disabled
and self.__class__.__name__ == q.name
and q.priority > self._priority
and key in q.parameters
):
ret = q.parameters[key]
return dict(ret) if isinstance(ret, Struct) else ret
Expand All @@ -139,12 +141,28 @@ def __getattr__(self, name: str):


class DriverType(type(JAMLCompatible), type):
"""A meta class representing a Driver
When a new Driver is created, it gets registered
"""

def __new__(cls, *args, **kwargs):
"""Create and register a new class with this meta class.
:param *args: *args for super
:param **kwargs: **kwargs for super
:return: the newly registered class
"""
_cls = super().__new__(cls, *args, **kwargs)
return cls.register_class(_cls)

@staticmethod
def register_class(cls):
"""Register a class
:param cls: the class
:return: the class, after being registered
"""
reg_cls_set = getattr(cls, '_registered_class', set())
if cls.__name__ not in reg_cls_set or getattr(cls, 'force_register', False):
wrap_func(cls, ['__init__'], store_init_kwargs)
Expand Down Expand Up @@ -189,12 +207,17 @@ def attach(self, runtime: 'ZEDRuntime', *args, **kwargs) -> None:

@property
def req(self) -> 'Request':
"""Get the current (typed) request, shortcut to ``self.runtime.request``"""
"""Get the current (typed) request, shortcut to ``self.runtime.request``
# noqa: DAR201
"""
return self.runtime.request

@property
def partial_reqs(self) -> Sequence['Request']:
"""The collected partial requests under the current ``request_id`` """
"""The collected partial requests under the current ``request_id``
# noqa: DAR401
# noqa: DAR201
"""
if self.expect_parts > 1:
return self.runtime.partial_requests
else:
Expand All @@ -205,27 +228,42 @@ def partial_reqs(self) -> Sequence['Request']:

@property
def expect_parts(self) -> int:
"""The expected number of partial messages """
"""The expected number of partial messages
# noqa: DAR201
"""
return self.runtime.expect_parts

@property
def msg(self) -> 'Message':
"""Get the current request, shortcut to ``self.runtime.message``"""
"""Get the current request, shortcut to ``self.runtime.message``
# noqa: DAR201
"""
return self.runtime.message

@property
def queryset(self) -> 'QueryLangSet':
"""
# noqa: DAR101
# noqa: DAR102
# noqa: DAR201
"""
if self.msg:
return self.msg.request.queryset
else:
return []

@property
def logger(self) -> 'JinaLogger':
"""Shortcut to ``self.runtime.logger``"""
"""Shortcut to ``self.runtime.logger``
# noqa: DAR201
"""
return self.runtime.logger

def __call__(self, *args, **kwargs) -> None:
"""
# noqa: DAR102
# noqa: DAR101
"""
raise NotImplementedError

def __eq__(self, other):
Expand All @@ -249,29 +287,33 @@ class RecursiveMixin(BaseDriver):

@property
def docs(self):
"""
# noqa: DAR102
# noqa: DAR201
"""
if self.expect_parts > 1:
return (d for r in reversed(self.partial_reqs) for d in r.docs)
else:
return self.req.docs

def _apply_root(
self,
docs: 'DocumentSet',
field: str,
*args,
**kwargs,
self,
docs: 'DocumentSet',
field: str,
*args,
**kwargs,
) -> None:
return self._apply_all(docs, None, field, *args, **kwargs)

# TODO(Han): probably want to publicize this, as it is not obvious for driver
# developer which one should be inherited
def _apply_all(
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs,
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs,
) -> None:
"""Apply function works on a list of docs, modify the docs in-place
Expand All @@ -283,6 +325,11 @@ def _apply_all(
"""

def __call__(self, *args, **kwargs):
"""Call the Driver.
:param *args: *args for ``_traverse_apply``
:param **kwargs: **kwargs for ``_traverse_apply``
"""
self._traverse_apply(self.docs, *args, **kwargs)

def _traverse_apply(self, docs: 'DocumentSet', *args, **kwargs) -> None:
Expand Down Expand Up @@ -326,10 +373,17 @@ class FastRecursiveMixin:
"""

def __call__(self, *args, **kwargs):
"""Traverse with _apply_all
:param *args: *args for ``_apply_all``
:param **kwargs: **kwargs for ``_apply_all``
"""
self._apply_all(self.docs, *args, **kwargs)

@property
def docs(self) -> 'DocumentSet':
"""The DocumentSet after applying the traversal
# noqa: DAR201"""
from ..types.sets import DocumentSet

if self.expect_parts > 1:
Expand Down Expand Up @@ -386,7 +440,9 @@ def __init__(self, executor: Optional[str] = None,

@property
def exec(self) -> 'AnyExecutor':
"""the executor that to which the instance is attached"""
"""the executor that to which the instance is attached
# noqa: DAR201
"""
return self._exec

@property
Expand All @@ -396,8 +452,8 @@ def exec_fn(self) -> Callable:
:return: the Callable to execute in the driver
"""
if (
not self.msg.is_error
or self.runtime.args.on_error_strategy < OnErrorStrategy.SKIP_EXECUTOR
not self.msg.is_error
or self.runtime.args.on_error_strategy < OnErrorStrategy.SKIP_EXECUTOR
):
return self._exec_fn
else:
Expand All @@ -417,7 +473,7 @@ def attach(self, executor: 'AnyExecutor', *args, **kwargs) -> None:
else:
for c in executor.components:
if any(
t.__name__ == self._executor_name for t in type.mro(c.__class__)
t.__name__ == self._executor_name for t in type.mro(c.__class__)
):
self._exec = c
break
Expand Down
44 changes: 32 additions & 12 deletions jina/drivers/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class BaseControlDriver(BaseDriver):

@property
def envelope(self) -> 'jina_pb2.EnvelopeProto':
"""Get the current request, shortcut to ``self.runtime.message``"""
"""Get the current request, shortcut to ``self.runtime.message``
# noqa: DAR201
"""
return self.msg.envelope


Expand All @@ -28,14 +30,19 @@ def __init__(self, key: str = 'request', json: bool = True, *args, **kwargs):
"""
:param key: (str) that represents a first level or nested key in the dict
:param json: (bool) indicating if the log output should be formatted as json
:param *args:
:param **kwargs:
:param *args: *args for super
:param **kwargs: **kwargs for super
"""
super().__init__(*args, **kwargs)
self.key = key
self.json = json

def __call__(self, *args, **kwargs):
"""Log the information.
:param *args: unused
:param **kwargs: unused
"""
data = dunder_get(self.msg.proto, self.key)
if self.json:
self.logger.info(
Expand All @@ -49,13 +56,20 @@ class WaitDriver(BaseControlDriver):
"""Wait for some seconds, mainly for demo purpose"""

def __call__(self, *args, **kwargs):
"""Wait for some seconds, mainly for demo purpose
# noqa: DAR101"""
time.sleep(5)


class ControlReqDriver(BaseControlDriver):
"""Handling the control request, by default it is installed for all :class:`jina.peapods.peas.BasePea`"""

def __call__(self, *args, **kwargs):
"""Handle the request controlling.
:param *args: unused
:param **kwargs: unused
"""
if self.req.command == 'TERMINATE':
self.envelope.status.code = jina_pb2.StatusProto.SUCCESS
raise RuntimeTerminated
Expand All @@ -79,28 +93,34 @@ class RouteDriver(ControlReqDriver):
- The router receives requests from both dealer and upstream pusher.
if it is an upstream request, use LB to schedule the receiver,
mark it in the envelope if it is a control request in
:param raise_no_dealer: raise a RuntimeError when no available dealer
:param *args: *args for super
:param **kwargs: **kwargs for super
"""

def __init__(self, raise_no_dealer: bool = False, *args, **kwargs):
"""
:param raise_no_dealer: raise a RuntimeError when no available dealer
:param *args:
:param **kwargs:
"""
super().__init__(*args, **kwargs)
self.idle_dealer_ids = set()
self.is_pollin_paused = False
self.is_polling_paused = False
self.raise_no_dealer = raise_no_dealer

def __call__(self, *args, **kwargs):
"""Perform the routing.
:param *args: *args for super().__call__
:param **kwargs: **kwargs for super().__call__
# noqa: DAR401
"""
if self.msg.is_data_request:
self.logger.debug(self.idle_dealer_ids)
if self.idle_dealer_ids:
dealer_id = self.idle_dealer_ids.pop()
self.envelope.receiver_id = dealer_id
if not self.idle_dealer_ids:
self.runtime._zmqlet.pause_pollin()
self.is_pollin_paused = True
self.is_polling_paused = True
elif self.raise_no_dealer:
raise RuntimeError('if this router connects more than one dealer, '
'then this error should never be raised. often when it '
Expand All @@ -119,9 +139,9 @@ def __call__(self, *args, **kwargs):
elif self.req.command == 'IDLE':
self.idle_dealer_ids.add(self.envelope.receiver_id)
self.logger.debug(f'{self.envelope.receiver_id} is idle, now I know these idle peas {self.idle_dealer_ids}')
if self.is_pollin_paused:
if self.is_polling_paused:
self.runtime._zmqlet.resume_pollin()
self.is_pollin_paused = False
self.is_polling_paused = False
raise NoExplicitMessage
else:
super().__call__(*args, **kwargs)
Expand Down
19 changes: 17 additions & 2 deletions jina/drivers/convertdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class ConvertDriver(RecursiveMixin, BaseRecursiveDriver):
def __init__(self, convert_fn: str, *args, **kwargs):
"""
:param convert_fn: the method name from `:class:`Document` to be applied
:param *args:
:param **kwargs: the set of named arguments to be passed to the method with name `convert_fn`
:param *args: *args for super
:param **kwargs: the set of named arguments to be passed to `convert_fn`
"""
super().__init__(*args, **kwargs)
self._convert_fn = convert_fn
Expand All @@ -32,40 +32,55 @@ def _apply_all(


class URI2Buffer(ConvertDriver):
"""Driver to convert URI to buffer"""
def __init__(self, convert_fn: str = 'convert_uri_to_buffer', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class URI2DataURI(ConvertDriver):
"""Driver to convert URI to data URI
"""
def __init__(self, convert_fn: str = 'convert_uri_to_data_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class Buffer2URI(ConvertDriver):
"""Driver to convert buffer to URI
"""
def __init__(self, convert_fn: str = 'convert_buffer_to_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class BufferImage2Blob(ConvertDriver):
"""Driver to convert image buffer to blob
"""
def __init__(self, convert_fn: str = 'convert_buffer_image_to_blob', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class URI2Blob(ConvertDriver):
"""Driver to convert URI to blob
"""
def __init__(self, convert_fn: str = 'convert_uri_to_blob', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class Text2URI(ConvertDriver):
"""Driver to convert text to URI
"""
def __init__(self, convert_fn: str = 'convert_text_to_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class URI2Text(ConvertDriver):
"""Driver to convert URI to text
"""
def __init__(self, convert_fn: str = 'convert_uri_to_text', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class Blob2PngURI(ConvertDriver):
"""Driver to convert blob to URI
"""
def __init__(self, convert_fn: str = 'convert_blob_to_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)

0 comments on commit adba5aa

Please sign in to comment.