Skip to content

Commit

Permalink
Merge 703cf19 into 04b68a8
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Jul 30, 2018
2 parents 04b68a8 + 703cf19 commit 1ca49d5
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 6 deletions.
1 change: 1 addition & 0 deletions aiomisc/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def start():
yield loop
except Exception as e:
loop.run_until_complete(graceful_shutdown(services, loop, e))
raise
else:
loop.run_until_complete(graceful_shutdown(services, loop, None))
finally:
Expand Down
88 changes: 87 additions & 1 deletion aiomisc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging.handlers
import socket
from multiprocessing import cpu_count
from typing import Iterable, Any, Tuple
from typing import Iterable, Any, Tuple, Coroutine, List

import uvloop

Expand Down Expand Up @@ -73,3 +73,89 @@ def new_event_loop(pool_size=None) -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)

return loop


_TASKS_LIST = List[asyncio.Task]


def wait_for(*coros: Tuple[Coroutine, ...],
raise_first: bool = True,
cancel: bool = True,
loop: asyncio.AbstractEventLoop = None):

tasks = list() # type: _TASKS_LIST
loop = loop or asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
result_future = loop.create_future() # type: asyncio.Future
waiting = len(coros)

def cancel_pending():
nonlocal result_future
nonlocal tasks

for t in tasks:
if t.done():
continue

t.cancel()

def raise_first_exception(exc: Exception):
nonlocal result_future
nonlocal tasks

if result_future.done():
return

result_future.set_exception(exc)

def return_result():
nonlocal result_future
nonlocal tasks

if result_future.done():
return

results = []

for task in tasks:
exc = task.exception()

results.append(task.result() if exc is None else exc)

result_future.set_result(results)

def done_callback(t: asyncio.Future):
nonlocal tasks
nonlocal result_future
nonlocal waiting

waiting -= 1

exc = t.exception()

if t.cancelled() or exc is None:
if waiting == 0:
return_result()

return

if raise_first:
raise_first_exception(exc)

if waiting == 0:
return_result()

for coro in coros:
task = loop.create_task(coro)
task.add_done_callback(done_callback)
tasks.append(task)

async def run():
nonlocal result_future

try:
return await result_future
finally:
if cancel:
cancel_pending()

return run()
10 changes: 6 additions & 4 deletions tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ async def stop(self, err: Exception = None):
DummyService(running=False, stopped=False),
)

with entrypoint(*services):
raise RuntimeError
with pytest.raises(RuntimeError):
with entrypoint(*services):
raise RuntimeError

for svc in services:
assert svc.running
Expand All @@ -94,8 +95,9 @@ async def stop(self, err: Exception = None):
StartingService(running=False),
)

with entrypoint(*services):
raise RuntimeError
with pytest.raises(RuntimeError):
with entrypoint(*services):
raise RuntimeError

for svc in services:
assert svc.running
Expand Down
67 changes: 66 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import socket
Expand All @@ -6,8 +7,9 @@

import pytest

from aiomisc.utils import chunk_list, bind_socket
from aiomisc.entrypoint import entrypoint
from aiomisc.log import basic_config
from aiomisc.utils import bind_socket, chunk_list, wait_for


def test_chunk_list(event_loop):
Expand Down Expand Up @@ -59,3 +61,66 @@ def test_bind_address(address, family, unused_tcp_port):

assert isinstance(sock, socket.socket)
assert sock.family == family


def test_wait_for_dummy():
with entrypoint() as loop:
results = loop.run_until_complete(
wait_for(*[asyncio.sleep(0.1) for _ in range(100)])
)

assert len(results) == 100
assert results == [None] * 100


def test_wait_for_exception():
async def coro(arg):
await asyncio.sleep(0.1)
assert arg != 15
return arg

with entrypoint() as loop:
with pytest.raises(AssertionError):
loop.run_until_complete(
wait_for(*[coro(i) for i in range(100)])
)

results = loop.run_until_complete(
wait_for(
*[coro(i) for i in range(17)],
raise_first=False
),
)

assert results
assert len(results) == 17
assert isinstance(results[15], AssertionError)
assert results[:15] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
assert results[16:] == [16]


def test_wait_for_cancelling():
results = []

async def coro(arg):
nonlocal results
await asyncio.sleep(0.1)
assert arg != 15

if arg > 15:
await asyncio.sleep(1)

results.append(arg)

with entrypoint() as loop:
with pytest.raises(AssertionError):
loop.run_until_complete(
wait_for(*[coro(i) for i in range(100)])
)

loop.run_until_complete(asyncio.sleep(2))

assert results
assert len(results) == 15
assert len(set(results)) == 15
assert frozenset(results) == frozenset(range(15))

0 comments on commit 1ca49d5

Please sign in to comment.