-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
hanhxiao
committed
Dec 18, 2018
1 parent
94d9065
commit eafb30f
Showing
2 changed files
with
101 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from contextlib import ExitStack | ||
|
||
from zmq.decorators import _Decorator | ||
|
||
__all__ = ['multi_socket'] | ||
|
||
from functools import wraps | ||
|
||
import zmq | ||
|
||
|
||
class _MyDecorator(_Decorator): | ||
def __call__(self, *dec_args, **dec_kwargs): | ||
kw_name, dec_args, dec_kwargs = self.process_decorator_args(*dec_args, **dec_kwargs) | ||
num_socket = dec_kwargs.pop('num_socket') | ||
|
||
def decorator(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
targets = [self.get_target(*args, **kwargs) for _ in range(num_socket)] | ||
with ExitStack() as stack: | ||
for target in targets: | ||
obj = stack.enter_context(target(*dec_args, **dec_kwargs)) | ||
args = args + (obj,) | ||
|
||
return func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
class _SocketDecorator(_MyDecorator): | ||
def process_decorator_args(self, *args, **kwargs): | ||
"""Also grab context_name out of kwargs""" | ||
kw_name, args, kwargs = super(_SocketDecorator, self).process_decorator_args(*args, **kwargs) | ||
self.context_name = kwargs.pop('context_name', 'context') | ||
return kw_name, args, kwargs | ||
|
||
def get_target(self, *args, **kwargs): | ||
"""Get context, based on call-time args""" | ||
context = self._get_context(*args, **kwargs) | ||
return context.socket | ||
|
||
def _get_context(self, *args, **kwargs): | ||
if self.context_name in kwargs: | ||
ctx = kwargs[self.context_name] | ||
|
||
if isinstance(ctx, zmq.Context): | ||
return ctx | ||
|
||
for arg in args: | ||
if isinstance(arg, zmq.Context): | ||
return arg | ||
# not specified by any decorator | ||
return zmq.Context.instance() | ||
|
||
|
||
def multi_socket(*args, **kwargs): | ||
return _SocketDecorator()(*args, **kwargs) |