Skip to content

Commit

Permalink
Merge cf00d6f into 62b36c2
Browse files Browse the repository at this point in the history
  • Loading branch information
escapewindow committed Jul 16, 2018
2 parents 62b36c2 + cf00d6f commit 6f84055
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 23 deletions.
20 changes: 17 additions & 3 deletions scriptworker/context.py
Expand Up @@ -55,6 +55,7 @@ class Context(object):
temp_queue = None
_credentials = None
_claim_task = None # This assumes a single task per worker.
_event_loop = None
_temp_credentials = None # This assumes a single task per worker.
_reclaim_task = None
_projects = None
Expand Down Expand Up @@ -114,9 +115,7 @@ def create_queue(self, credentials):
"""
if credentials:
session = self.session or aiohttp.ClientSession(
loop=asyncio.get_event_loop()
)
session = self.session or aiohttp.ClientSession(loop=self.event_loop)
return Queue({
'credentials': credentials,
}, session=session)
Expand Down Expand Up @@ -186,6 +185,21 @@ def projects(self):
def projects(self, projects):
self._projects = projects

@property
def event_loop(self):
"""asyncio.BaseEventLoop: the running event loop.
This fixture mainly exists to allow for overrides during unit tests.
"""
if not self._event_loop:
self._event_loop = asyncio.get_event_loop()
return self._event_loop

@event_loop.setter
def event_loop(self, event_loop):
self._event_loop = event_loop

async def populate_projects(self, force=False):
"""Download the ``projects.yml`` file and populate ``self.projects``.
Expand Down
17 changes: 11 additions & 6 deletions scriptworker/gpg.py
Expand Up @@ -1358,13 +1358,12 @@ def build_gpg_homedirs_from_repo(
"""
basedir = basedir or context.config['base_gpg_home_dir']
repo_path = context.config['git_key_repo_dir']
event_loop = asyncio.get_event_loop()
# verify our input. Hardcoding the check before importing, as opposed
# to expecting something else to run the check for us.
# This currently runs twice, once to update to the tag and once before
# we build the homedirs, in case we ever call this function without calling
# ``update_signed_git_repo`` first.
event_loop.run_until_complete(verify_function(context, tag))
context.event_loop.run_until_complete(verify_function(context, tag))
rm(basedir)
makedirs(basedir)
# create gpg homedirs
Expand Down Expand Up @@ -1405,9 +1404,8 @@ def _update_git_and_rebuild_homedirs(context, basedir=None):
trusted_path
)
overwrite_gpg_home(tmp_gpg_home, guess_gpg_home(context))
event_loop = asyncio.get_event_loop()
old_revision = get_last_good_git_revision(context)
new_revision, tag = event_loop.run_until_complete(
new_revision, tag = context.event_loop.run_until_complete(
retry_async(
update_signed_git_repo, retry_exceptions=(ScriptWorkerRetryException, ),
args=(context, )
Expand Down Expand Up @@ -1483,16 +1481,21 @@ def rm_lockfile(context):
rm(context.config['gpg_lockfile'])


def rebuild_gpg_homedirs():
def rebuild_gpg_homedirs(event_loop=None):
"""Rebuild the gpg homedirs in the background.
This is an entry point, and should be called before scriptworker is run.
Args:
event_loop (asyncio.BaseEventLoop, optional): the event loop to use.
If None, use ``asyncio.get_event_loop()``. Defaults to None.
Raises:
SystemExit: on failure.
"""
context, _ = get_context_from_cmdln(sys.argv[1:])
context.event_loop = event_loop or context.event_loop
update_logging_config(context, file_name='rebuild_gpg_homedirs.log')
log.info("rebuild_gpg_homedirs()...")
basedir = get_tmp_base_gpg_home_dir(context)
Expand All @@ -1501,7 +1504,9 @@ def rebuild_gpg_homedirs():
create_lockfile(context)
new_revision = None
try:
new_revision = _update_git_and_rebuild_homedirs(context, basedir=basedir)
new_revision = _update_git_and_rebuild_homedirs(
context, basedir=basedir
)
except ScriptWorkerException as exc:
log.exception("Failed to run _update_git_and_rebuild_homedirs")
sys.exit(exc.exit_code)
Expand Down
3 changes: 2 additions & 1 deletion scriptworker/test/__init__.py
Expand Up @@ -270,7 +270,7 @@ async def _close_session(obj):

@pytest.mark.asyncio
@pytest.yield_fixture(scope='function', params=['firefox'])
async def rw_context(request):
async def rw_context(request, event_loop):
with tempfile.TemporaryDirectory() as tmp:
config = get_unfrozen_copy(DEFAULT_CONFIG)
config['cot_product'] = request.param
Expand All @@ -285,6 +285,7 @@ async def rw_context(request):
if key.endswith("key_path") or key in ("gpg_home", ):
context.config[key] = os.path.join(tmp, key)
context.config['verbose'] = VERBOSE
context.event_loop = event_loop
yield context
await _close_session(context)
await _close_session(context.queue)
Expand Down
22 changes: 22 additions & 0 deletions scriptworker/test/test_context.py
Expand Up @@ -2,7 +2,9 @@
# coding=utf-8
"""Test scriptworker.context
"""
import asyncio
import json
import mock
import os
import pytest
import taskcluster
Expand Down Expand Up @@ -124,3 +126,23 @@ def test_get_credentials(context):
expected = {'asdf': 'foobar'}
context._credentials = expected
assert context.credentials == expected


def test_new_event_loop(mocker):
"""The default context.event_loop is from `asyncio.get_event_loop`"""
fake_loop = mock.MagicMock()
mocker.patch.object(asyncio, 'get_event_loop', return_value=fake_loop)
context = swcontext.Context()
assert context.event_loop == fake_loop


def test_set_event_loop(mocker):
"""`context.event_loop` returns the same value once set.
(This may seem obvious, but this tests the correctness of the property.)
"""
fake_loop = mock.MagicMock()
context = swcontext.Context()
context.event_loop = fake_loop
assert context.event_loop == fake_loop
6 changes: 3 additions & 3 deletions scriptworker/test/test_gpg.py
Expand Up @@ -767,7 +767,7 @@ async def new_revision(*args, **kwargs):
mocker.patch.object(sgpg, "build_gpg_homedirs_from_repo", new=noop_sync)
mocker.patch.object(sgpg, "write_last_good_git_revision", new=noop_sync)

sgpg.rebuild_gpg_homedirs()
sgpg.rebuild_gpg_homedirs(event_loop=context.event_loop)


@pytest.mark.parametrize("nuke_dir", (True, False))
Expand All @@ -790,7 +790,7 @@ def fake_context(*args):
mocker.patch.object(sgpg, "build_gpg_homedirs_from_repo", new=noop_sync)

with pytest.raises(SystemExit):
sgpg.rebuild_gpg_homedirs()
sgpg.rebuild_gpg_homedirs(event_loop=context.event_loop)


def test_rebuild_gpg_homedirs_lockfile(context, mocker):
Expand All @@ -802,7 +802,7 @@ def fake_context(*args):
mocker.patch.object(sgpg, "update_logging_config", new=noop_sync)

touch(context.config['gpg_lockfile'])
sgpg.rebuild_gpg_homedirs()
sgpg.rebuild_gpg_homedirs(event_loop=context.event_loop)


# last_good_git_revision {{{1
Expand Down
8 changes: 4 additions & 4 deletions scriptworker/test/test_worker.py
Expand Up @@ -33,7 +33,7 @@ def context(rw_context):


# main {{{1
def test_main(mocker, context):
def test_main(mocker, context, event_loop):
config = dict(context.config)
config['poll_interval'] = 1
creds = {'fake_creds': True}
Expand All @@ -52,12 +52,12 @@ async def foo(arg, credentials):
mocker.patch.object(worker, 'async_main', new=foo)
mocker.patch.object(sys, 'argv', new=['x', tmp])
with pytest.raises(ScriptWorkerException):
worker.main()
worker.main(event_loop=event_loop)
finally:
os.remove(tmp)


def test_main_sigterm(mocker, context):
def test_main_sigterm(mocker, context, event_loop):
"""Test that sending SIGTERM causes the main loop to stop after the next
call to async_main."""
config = dict(context.config)
Expand All @@ -77,7 +77,7 @@ async def async_main(*args):
del(config['credentials'])
mocker.patch.object(worker, 'async_main', new=async_main)
mocker.patch.object(sys, 'argv', new=['x', tmp])
worker.main()
worker.main(event_loop=event_loop)
finally:
os.remove(tmp)

Expand Down
17 changes: 11 additions & 6 deletions scriptworker/worker.py
Expand Up @@ -111,7 +111,6 @@ async def run_tasks(context, creds_key="credentials"):
None: if no task run.
"""
loop = asyncio.get_event_loop()
tasks = await claim_work(context)
status = None
if not tasks or not tasks.get('tasks', []):
Expand All @@ -124,7 +123,7 @@ async def run_tasks(context, creds_key="credentials"):
for task_defn in tasks.get('tasks', []):
status = 0
prepare_to_run_task(context, task_defn)
reclaim_fut = loop.create_task(reclaim_task(context, context.task))
reclaim_fut = context.event_loop.create_task(reclaim_task(context, context.task))
status = await do_run_task(context)
status = worst_level(status, await do_upload(context))
await complete_task(context, status)
Expand Down Expand Up @@ -158,12 +157,18 @@ async def async_main(context, credentials):


# main {{{1
def main():
"""Scriptworker entry point: get everything set up, then enter the main loop."""
def main(event_loop=None):
"""Scriptworker entry point: get everything set up, then enter the main loop.
Args:
event_loop (asyncio.BaseEventLoop, optional): the event loop to use.
If None, use ``asyncio.get_event_loop()``. Defaults to None.
"""
context, credentials = get_context_from_cmdln(sys.argv[1:])
log.info("Scriptworker starting up at {} UTC".format(arrow.utcnow().format()))
cleanup(context)
loop = asyncio.get_event_loop()
context.event_loop = event_loop or asyncio.get_event_loop()

done = False

Expand All @@ -176,7 +181,7 @@ def _handle_sigterm(signum, frame):

while not done:
try:
loop.run_until_complete(async_main(context, credentials))
context.event_loop.run_until_complete(async_main(context, credentials))
except Exception:
log.critical("Fatal exception", exc_info=1)
raise

0 comments on commit 6f84055

Please sign in to comment.