Skip to content

Commit

Permalink
Merge pull request #230 from escapewindow/concurrent-tests
Browse files Browse the repository at this point in the history
address event loop bustage in concurrent tests
  • Loading branch information
escapewindow committed Jun 1, 2018
2 parents cace69c + 31d37ab commit dd7b974
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 28 deletions.
12 changes: 8 additions & 4 deletions scriptworker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def callback(match):
return unquote(filepath).lstrip('/')


def sync_main(async_main, config_path=None, default_config=None, should_validate_task=True):
def sync_main(async_main, config_path=None, default_config=None,
should_validate_task=True, loop_function=asyncio.get_event_loop):
"""Entry point for scripts using scriptworker.
This function sets up the basic needs for a script to run. More specifically:
Expand All @@ -145,18 +146,21 @@ def sync_main(async_main, config_path=None, default_config=None, should_validate
Args:
async_main (function): The function to call once everything is set up
config_path (str, optional): The path to the file to load the config from.
Loads from `sys.argv[1]` if `None`. Defaults to None.
default_config (dict, optional): the default config to use for `_init_context`.
Loads from ``sys.argv[1]`` if ``None``. Defaults to None.
default_config (dict, optional): the default config to use for ``_init_context``.
defaults to None.
should_validate_task (bool, optional): whether we should validate the task
schema. Defaults to True.
loop_function (function, optional): the function to call to get the
event loop; here for testing purposes. Defaults to
``asyncio.get_event_loop``.
"""
context = _init_context(config_path, default_config)
_init_logging(context)
if should_validate_task:
validate_task_schema(context)
loop = asyncio.get_event_loop()
loop = loop_function()
loop.run_until_complete(_handle_asyncio_loop(async_main, context))


Expand Down
26 changes: 13 additions & 13 deletions scriptworker/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def __init__(self, *args, status=200, payload=None, **kwargs):
self._connection = mock.MagicMock()
self._payload = payload or {}
self.status = status
self.headers = {'content-type': 'application/json'}
self._headers = {'content-type': 'application/json'}
self._cache = {}
self._loop = mock.MagicMock()
self.content = self
self.resp = [b"asdf", b"asdf"]
Expand Down Expand Up @@ -165,35 +166,34 @@ def unsuccessful_queue():
return UnsuccessfulQueue()


@pytest.mark.asyncio
@pytest.fixture(scope='function')
def fake_session():
async def fake_session():
@asyncio.coroutine
def _fake_request(method, url, *args, **kwargs):
resp = FakeResponse(method, url)
resp._history = (FakeResponse(method, url, status=302),)
return resp

loop = asyncio.get_event_loop()
session = aiohttp.ClientSession(loop=loop)
session = aiohttp.ClientSession()
session._request = _fake_request
yield session
loop = asyncio.get_event_loop()
loop.run_until_complete(session.close())
await session.close()


@pytest.mark.asyncio
@pytest.fixture(scope='function')
def fake_session_500(event_loop):
async def fake_session_500():
@asyncio.coroutine
def _fake_request(method, url, *args, **kwargs):
resp = FakeResponse(method, url, status=500)
resp._history = (FakeResponse(method, url, status=302),)
return resp

loop = asyncio.get_event_loop()
session = aiohttp.ClientSession(loop=loop)
session = aiohttp.ClientSession()
session._request = _fake_request
yield session
loop.run_until_complete(session.close())
await session.close()


def integration_create_task_payload(config, task_group_id, scopes=None,
Expand Down Expand Up @@ -270,8 +270,9 @@ def tmpdir2():
yield tmp


@pytest.mark.asyncio
@pytest.yield_fixture(scope='function', params=['firefox'])
def rw_context(request):
async def rw_context(request):
with tempfile.TemporaryDirectory() as tmp:
config = get_unfrozen_copy(DEFAULT_CONFIG)
config['cot_product'] = request.param
Expand All @@ -288,8 +289,7 @@ def rw_context(request):
context.config['verbose'] = VERBOSE
yield context
try:
loop = asyncio.get_event_loop()
loop.run_until_complete(context.session.close())
await context.session.close()
except:
pass

Expand Down
26 changes: 20 additions & 6 deletions scriptworker/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,26 @@ def test_bad_artifact_url(valid_artifact_rules, valid_artifact_task_ids, url):
client.validate_artifact_url(valid_artifact_rules, valid_artifact_task_ids, url)


@pytest.mark.asyncio
@pytest.mark.parametrize('should_validate_task', (True, False))
def test_sync_main_runs_fully(config, event_loop, should_validate_task):
async def test_sync_main_runs_fully(config, event_loop, should_validate_task):
copyfile(BASIC_TASK, os.path.join(config['work_dir'], 'task.json'))
generator = (n for n in range(0, 2))
async_main_calls = []
run_until_complete_calls = []

async def async_main(*args):
async_main_calls.append(args)

def count_run_until_complete(arg1):
run_until_complete_calls.append(arg1)

fake_loop = MagicMock()
fake_loop.run_until_complete = count_run_until_complete

async def async_main(_):
next(generator)
def loop_function():
return fake_loop

kwargs = {}
kwargs = {'loop_function': loop_function}

if should_validate_task:
schema_path = os.path.join(config['work_dir'], 'schema.json')
Expand All @@ -232,7 +243,10 @@ async def async_main(_):
kwargs['config_path'] = f.name
client.sync_main(async_main, **kwargs)

assert next(generator) == 1 # async_main was called once
for i in run_until_complete_calls:
await i # suppress coroutine not awaited warning
assert len(run_until_complete_calls) == 1 # run_until_complete was called once
assert len(async_main_calls) == 1 # async_main was called once


@pytest.mark.parametrize('does_use_argv, default_config', (
Expand Down
6 changes: 1 addition & 5 deletions scriptworker/test/test_cot_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,16 +1696,13 @@ async def maybe_die(*args):

# verify_cot_cmdln {{{1
@pytest.mark.parametrize("args", (("x", "--task-type", "signing", "--cleanup"), ("x", "--task-type", "balrog")))
def test_verify_cot_cmdln(chain, args, tmpdir, mocker, event_loop):
def test_verify_cot_cmdln(chain, args, tmpdir, mocker):
context = mock.MagicMock()
context.queue = mock.MagicMock()
context.queue.task = noop_async
path = os.path.join(tmpdir, 'x')
makedirs(path)

def eloop():
return event_loop

def get_context():
return context

Expand All @@ -1719,7 +1716,6 @@ def cot(*args, **kwargs):
return m

mocker.patch.object(tempfile, 'mkdtemp', new=mkdtemp)
mocker.patch.object(asyncio, 'get_event_loop', new=eloop)
mocker.patch.object(cotverify, 'read_worker_creds', new=noop_sync)
mocker.patch.object(cotverify, 'Context', new=get_context)
mocker.patch.object(cotverify, 'ChainOfTrust', new=cot)
Expand Down

0 comments on commit dd7b974

Please sign in to comment.