Skip to content

Commit

Permalink
test: rescue old invalid test (#3707)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Oct 19, 2021
1 parent 86a6b98 commit a078c07
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 36 deletions.
12 changes: 11 additions & 1 deletion jina/peapods/peas/__init__.py
Expand Up @@ -96,6 +96,7 @@ def _set_envs():
finally:
_unset_envs()
is_shutdown.set()
logger.debug(f' Process terminated')


class BasePea:
Expand Down Expand Up @@ -194,14 +195,18 @@ def join(self, *args, **kwargs):
:param args: extra positional arguments to pass to join
:param kwargs: extra keyword arguments to pass to join
"""
self.logger.debug(f' Joining the process')
self.worker.join(*args, **kwargs)
self.logger.debug(f' Successfully joined the process')

def terminate(self):
"""Terminate the Pea.
This method calls :meth:`terminate` in :class:`threading.Thread` or :class:`multiprocesssing.Process`.
"""
if hasattr(self.worker, 'terminate'):
self.logger.debug(f' terminating the runtime process')
self.worker.terminate()
self.logger.debug(f' runtime process properly terminated')

def _retry_control_message(self, command: str, num_retry: int = 3):
from ..zmq import send_ctrl_message
Expand Down Expand Up @@ -345,11 +350,15 @@ def close(self) -> None:
"""
# if that 1s is not enough, it means the process/thread is still in forever loop, cancel it
self.logger.debug('waiting for ready or shutdown signal from runtime')
terminated = False
if self.is_ready.is_set() and not self.is_shutdown.is_set():
try:
self.logger.warning(f' Cancel runtime')
self._cancel_runtime()
self.logger.warning(f' Wait to shutdown')
if not self.is_shutdown.wait(timeout=self._timeout_ctrl):
self.terminate()
terminated = True
time.sleep(0.1)
raise Exception(
f'Shutdown signal was not received for {self._timeout_ctrl}'
Expand All @@ -362,7 +371,8 @@ def close(self) -> None:
else '',
exc_info=not self.args.quiet_error,
)
self.terminate()
if not terminated:
self.terminate()

# if it is not daemon, block until the process/thread finish work
if not self.args.daemon:
Expand Down
33 changes: 16 additions & 17 deletions tests/unit/test_gateway.py
Expand Up @@ -12,45 +12,44 @@

@pytest.mark.slow
@pytest.mark.parametrize('compress_algo', list(CompressAlgo))
def test_compression(compress_algo, mocker):

response_mock = mocker.Mock()

def test_compression(compress_algo):
f = Flow(compress=str(compress_algo)).add().add(name='DummyEncoder', shards=2).add()

with f:
f.index(random_docs(10), on_done=response_mock)
results = f.index(random_docs(10), return_results=True)

response_mock.assert_called()
assert len(results) > 0


@pytest.mark.slow
@pytest.mark.parametrize('protocol', ['websocket', 'grpc', 'http'])
def test_gateway_concurrency(protocol):
def test_gateway_concurrency(protocol, reraise):
PORT_EXPOSE = 12345
CONCURRENCY = 2
threads = []
status_codes = [None] * CONCURRENCY
durations = [None] * CONCURRENCY

def _validate(req, start, status_codes, durations, index):
end = time.time()
durations[index] = end - start
status_codes[index] = req.status.code

def _request(status_codes, durations, index):
start = time.time()
Client(port=PORT_EXPOSE, protocol=protocol).index(
inputs=(Document() for _ in range(256)),
on_done=functools.partial(
with reraise:
start = time.time()
on_done = functools.partial(
_validate,
start=start,
status_codes=status_codes,
durations=durations,
index=index,
),
batch_size=16,
)
)
results = Client(port=PORT_EXPOSE, protocol=protocol).index(
inputs=(Document() for _ in range(256)),
return_results=True,
_size=16,
)
assert len(results) > 0
for result in results:
on_done(result)

f = Flow(protocol=protocol, port_expose=PORT_EXPOSE).add(parallel=2)
with f:
Expand Down
@@ -1,26 +1,50 @@
import time

from jina import __default_executor__
from jina.helper import random_identity
from jina.logging.predefined import default_logger
from jina.parsers import set_pea_parser
from jina.peapods.runtimes.zmq.zed import ZEDRuntime
from jina.peapods.peas import BasePea
from jina.peapods.zmq import Zmqlet
from jina.types.message import Message
from jina.types.request import Request
from jina import Executor, requests
from tests import validate_callback


class MockBasePeaNotRead(BasePea):
def _post_hook(self, msg: 'Message') -> 'BasePea':
class DecompressExec(Executor):
@requests()
def func(self, docs, **kwargs):
for doc in docs:
doc.text = 'used'


class MockRuntimeNotDecompressed(ZEDRuntime):
def _post_hook(self, msg: 'Message'):
super()._post_hook(msg)
assert not msg.request.is_decompressed
if msg is not None:
decompressed = msg.request.is_decompressed
if msg.is_data_request:
assert not decompressed
return msg


class MockBasePeaRead(BasePea):
def _post_hook(self, msg: 'Message') -> 'BasePea':
class MockRuntimeDecompressed(ZEDRuntime):
def _post_hook(self, msg: 'Message'):
super()._post_hook(msg)
assert msg.request.is_decompressed
if msg is not None:
decompressed = msg.request.is_decompressed
if msg.is_data_request:
assert decompressed
return msg


class MockPea(BasePea):
def _get_runtime_cls(self):
if self.args.runtime_cls == 'MockRuntimeNotDecompressed':
return MockRuntimeNotDecompressed
else:
return MockRuntimeDecompressed


args1 = set_pea_parser().parse_args(
Expand All @@ -33,8 +57,6 @@ def _post_hook(self, msg: 'Message') -> 'BasePea':
'PULL_CONNECT',
'--socket-out',
'PUSH_CONNECT',
'--timeout-ctrl',
'-1',
]
)

Expand All @@ -52,8 +74,8 @@ def _post_hook(self, msg: 'Message') -> 'BasePea':
'PULL_BIND',
'--socket-out',
'PUSH_BIND',
'--timeout-ctrl',
'-1',
'--runtime-cls',
'MockRuntimeNotDecompressed',
]
)

Expand All @@ -72,31 +94,54 @@ def _post_hook(self, msg: 'Message') -> 'BasePea':
'--socket-out',
'PUSH_BIND',
'--uses',
__default_executor__, # will NOT trigger use
'--timeout-ctrl',
'-1',
'DecompressExec',
'--runtime-cls',
'MockRuntimeDecompressed',
]
)


def test_read_zmqlet():
with MockBasePeaRead(args2), Zmqlet(args1, default_logger) as z:
def test_not_decompressed_zmqlet(mocker):
with MockPea(args2) as pea, Zmqlet(args1, default_logger) as z:
req = Request()
req.request_id = random_identity()
d = req.data.docs.add()
d.tags['id'] = 2
msg = Message(None, req, 'tmp', '')
mock = mocker.Mock()
z.send_message(msg)
time.sleep(1)
z.recv_message(mock)

def callback(msg_):
pass

validate_callback(mock, callback)
print(f' joining pea')
pea.join()
print(f' joined pea')

def test_not_read_zmqlet():
with MockBasePeaNotRead(args3), Zmqlet(args1, default_logger) as z:

def test_decompressed_zmqlet(mocker):
with MockPea(args3) as pea, Zmqlet(args1, default_logger) as z:
req = Request()
req.request_id = random_identity()
d = req.data.docs.add()
d.tags['id'] = 2
msg = Message(None, req, 'tmp', '')

mock = mocker.Mock()
z.send_message(msg)
time.sleep(1)
z.recv_message(mock)

def callback(msg_):
pass

validate_callback(mock, callback)
print(f' joining pea')
pea.join()
print(f' joined pea')


def test_recv_message_zmqlet(mocker):
Expand Down

0 comments on commit a078c07

Please sign in to comment.