diff --git a/docs/changelog.rst b/docs/changelog.rst index 664da4dbf4..24c742083f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,12 @@ would have also been included in the 1.2 line. Changelog ========= +* :bug:`568` `~fabric.tasks.execute` allowed too much of its internal state + changes (to variables such as ``env.host_string`` and ``env.parallel``) to + persist after execution completed; this caused a number of different + incorrect behaviors. `~fabric.tasks.execute` has been overhauled to clean up + its own state changes -- while preserving any state changes made by the task + being executed. * :bug:`584` `~fabric.contrib.project.upload_project` did not take explicit remote directory location into account when untarring, and now uses `~fabric.context_managers.cd` to address this. Thanks to Ben Burry for the diff --git a/fabric/context_managers.py b/fabric/context_managers.py index 2bf81d7831..2d58c66c3f 100644 --- a/fabric/context_managers.py +++ b/fabric/context_managers.py @@ -10,7 +10,8 @@ from contextlib import contextmanager, nested import sys -from fabric.state import env, output, win32 +from fabric.state import output, win32 +from fabric import state if not win32: import termios @@ -84,20 +85,29 @@ def _setenv(**kwargs): This context manager is used internally by `settings` and is not intended to be used directly. """ + clean_revert = kwargs.pop('clean_revert', False) previous = {} new = [] for key, value in kwargs.iteritems(): - if key in env: - previous[key] = env[key] + if key in state.env: + previous[key] = state.env[key] else: new.append(key) - env[key] = value + state.env[key] = value try: yield finally: - env.update(previous) - for key in new: - del env[key] + if clean_revert: + for key, value in kwargs.iteritems(): + # If the current env value for this key still matches the + # value we set it to beforehand, we are OK to revert it to the + # pre-block value. + if value == state.env[key]: + state.env[key] = previous[key] + else: + state.env.update(previous) + for key in new: + del state.env[key] def settings(*args, **kwargs): @@ -109,6 +119,9 @@ def settings(*args, **kwargs): * Most usefully, it allows temporary overriding/updating of ``env`` with any provided keyword arguments, e.g. ``with settings(user='foo'):``. Original values, if any, will be restored once the ``with`` block closes. + * The keyword argument ``clean_revert`` has special meaning for + ``settings`` itself (see below) and will be stripped out before + execution. * In addition, it will use `contextlib.nested`_ to nest any given non-keyword arguments, which should be other context managers, e.g. ``with settings(hide('stderr'), show('stdout')):``. @@ -139,6 +152,41 @@ def my_task(): variables in tandem with hiding (or showing) specific levels of output, or in tandem with any other piece of Fabric functionality implemented as a context manager. + + If ``clean_revert`` is set to ``True``, ``settings`` will **not** revert + keys which are altered within the nested block, instead only reverting keys + whose values remain the same as those given. More examples will make this + clear; below is how ``settings`` operates normally:: + + # Before the block, env.parallel defaults to False, host_string to None + with settings(parallel=True, host_string='myhost'): + # env.parallel is True + # env.host_string is 'myhost' + env.host_string = 'otherhost' + # env.host_string is now 'otherhost' + # Outside the block: + # * env.parallel is False again + # * env.host_string is None again + + The internal modification of ``env.host_string`` is nullified -- not always + desirable. That's where ``clean_revert`` comes in:: + + # Before the block, env.parallel defaults to False, host_string to None + with settings(parallel=True, host_string='myhost', clean_revert=True): + # env.parallel is True + # env.host_string is 'myhost' + env.host_string = 'otherhost' + # env.host_string is now 'otherhost' + # Outside the block: + # * env.parallel is False again + # * env.host_string remains 'otherhost' + + Brand new keys which did not exist in ``env`` prior to using ``settings`` + are also preserved if ``clean_revert`` is active. When ``False``, such keys + are removed when the block exits. + + .. versionadded:: 1.4.1 + The ``clean_revert`` kwarg. """ managers = list(args) if kwargs: @@ -225,8 +273,8 @@ def lcd(path): def _change_cwd(which, path): path = path.replace(' ', '\ ') - if env.get(which) and not path.startswith('/'): - new_cwd = env.get(which) + '/' + path + if state.env.get(which) and not path.startswith('/'): + new_cwd = state.env.get(which) + '/' + path else: new_cwd = path return _setenv(**{which: new_cwd}) @@ -315,7 +363,7 @@ def prefix(command): Contrived, but hopefully illustrative. """ - return _setenv(command_prefixes=env.command_prefixes + [command]) + return _setenv(command_prefixes=state.env.command_prefixes + [command]) @contextmanager diff --git a/fabric/job_queue.py b/fabric/job_queue.py index f53740922e..68ad3612e6 100644 --- a/fabric/job_queue.py +++ b/fabric/job_queue.py @@ -5,11 +5,13 @@ items, though within Fabric itself only ``Process`` objects are used/supported. """ +from __future__ import with_statement import time import Queue from fabric.state import env from fabric.network import ssh +from fabric.context_managers import settings class JobQueue(object): @@ -115,8 +117,8 @@ def _advance_the_queue(): job = self._queued.pop() if self._debug: print("Popping '%s' off the queue and starting it" % job.name) - env.host_string = env.host = job.name - job.start() + with settings(clean_revert=True, host_string=job.name, host=job.name): + job.start() self._running.append(job) if not self._closed: diff --git a/fabric/tasks.py b/fabric/tasks.py index 2ff5aad17d..bf7f2cb18d 100644 --- a/fabric/tasks.py +++ b/fabric/tasks.py @@ -83,8 +83,7 @@ def get_pool_size(self, hosts, default): pool_size = min((pool_size, len(hosts))) # Inform user of final pool size for this task if state.output.debug: - msg = "Parallel tasks now using pool size of %d" - print msg % pool_size + print "Parallel tasks now using pool size of %d" % pool_size return pool_size @@ -149,53 +148,53 @@ def _execute(task, host, my_env, args, kwargs, jobs, queue, multiprocessing): # Create per-run env with connection settings local_env = to_dict(host) local_env.update(my_env) - state.env.update(local_env) - # Handle parallel execution - if queue is not None: # Since queue is only set for parallel - # Set a few more env flags for parallelism - state.env.parallel = True # triggers some extra aborts, etc - state.env.linewise = True # to mirror -P behavior - name = local_env['host_string'] - # Wrap in another callable that: - # * nukes the connection cache to prevent shared-access problems - # * knows how to send the tasks' return value back over a Queue - # * captures exceptions raised by the task - def inner(args, kwargs, queue, name): - key = normalize_to_string(state.env.host_string) - state.connections.pop(key, "") - try: - result = task.run(*args, **kwargs) - except BaseException, e: # We really do want to capture everything - result = e - # But still print it out, otherwise users won't know what the - # fuck. Especially if the task is run at top level and nobody's - # doing anything with the return value. - # BUT don't do this if it's a SystemExit as that implies use of - # abort(), which does its own printing. - if e.__class__ is not SystemExit: - print >> sys.stderr, "!!! Parallel execution exception under host %r:" % name - sys.excepthook(*sys.exc_info()) - # Conversely, if it IS SystemExit, we can raise it to ensure a - # correct return value. - else: - raise - queue.put({'name': name, 'result': result}) - - # Stuff into Process wrapper - kwarg_dict = { - 'args': args, - 'kwargs': kwargs, - 'queue': queue, - 'name': name - } - p = multiprocessing.Process(target=inner, kwargs=kwarg_dict) - # Name/id is host string - p.name = name - # Add to queue - jobs.append(p) - # Handle serial execution - else: - return task.run(*args, **kwargs) + # Set a few more env flags for parallelism + if queue is not None: + local_env.update({'parallel': True, 'linewise': True}) + with settings(**local_env): + # Handle parallel execution + if queue is not None: # Since queue is only set for parallel + name = local_env['host_string'] + # Wrap in another callable that: + # * nukes the connection cache to prevent shared-access problems + # * knows how to send the tasks' return value back over a Queue + # * captures exceptions raised by the task + def inner(args, kwargs, queue, name): + try: + key = normalize_to_string(state.env.host_string) + state.connections.pop(key, "") + result = task.run(*args, **kwargs) + except BaseException, e: # We really do want to capture everything + result = e + # But still print it out, otherwise users won't know what the + # fuck. Especially if the task is run at top level and nobody's + # doing anything with the return value. + # BUT don't do this if it's a SystemExit as that implies use of + # abort(), which does its own printing. + if e.__class__ is not SystemExit: + print >> sys.stderr, "!!! Parallel execution exception under host %r:" % name + sys.excepthook(*sys.exc_info()) + # Conversely, if it IS SystemExit, we can raise it to ensure a + # correct return value. + else: + raise + queue.put({'name': name, 'result': result}) + + # Stuff into Process wrapper + kwarg_dict = { + 'args': args, + 'kwargs': kwargs, + 'queue': queue, + 'name': name + } + p = multiprocessing.Process(target=inner, kwargs=kwarg_dict) + # Name/id is host string + p.name = name + # Add to queue + jobs.append(p) + # Handle serial execution + else: + return task.run(*args, **kwargs) def _is_task(task): return isinstance(task, Task) @@ -242,7 +241,7 @@ def execute(task, *args, **kwargs): Added the return value mapping; previously this function had no defined return value. """ - my_env = {} + my_env = {'clean_revert': True} results = {} # Obtain task is_callable = callable(task) @@ -324,7 +323,8 @@ def execute(task, *args, **kwargs): # Or just run once for local-only else: - state.env.update(my_env) - results[''] = task.run(*args, **new_kwargs) + with settings(**my_env): + results[''] = task.run(*args, **new_kwargs) # Return what we can from the inner task executions + return results diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8ed00ea437..df1d8f6272 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -91,3 +91,15 @@ def test_settings_with_other_context_managers(): ok_(env.testval1, "outer 1") eq_(env.lcwd, prev_lcwd) + +def test_settings_clean_revert(): + """ + settings(clean_revert=True) should only revert values matching input values + """ + env.modified = "outer" + env.notmodified = "outer" + with settings(modified="inner", notmodified="inner", clean_revert=True): + eq_(env.modified, "inner") + eq_(env.notmodified, "inner") + env.modified = "modified internally" + eq_(env.modified, "modified internally") diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c2ff614fcf..a8d5002207 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -401,3 +401,49 @@ class MyTask(Task): run = Fake(callable=True, expect_call=True) mytask = MyTask() execute(mytask) + + +class TestExecuteEnvInteractions(FabricTest): + def set_network(self): + # Don't update env.host/host_string/etc + pass + + @server(port=2200) + @server(port=2201) + def test_should_not_mutate_its_own_env_vars(self): + """ + internal env changes should not bleed out, but task env changes should + """ + # Task that uses a handful of features which involve env vars + @parallel + @hosts('username@127.0.0.1:2200', 'username@127.0.0.1:2201') + def mytask(): + run("ls /simple") + # Pre-assertions + assertions = { + 'parallel': False, + 'all_hosts': [], + 'host': None, + 'hosts': [], + 'host_string': None + } + for key, value in assertions.items(): + eq_(env[key], value) + # Run + with hide('everything'): + result = execute(mytask) + eq_(len(result), 2) + # Post-assertions + for key, value in assertions.items(): + eq_(env[key], value) + + @server() + def test_should_allow_task_to_modify_env_vars(self): + @hosts('username@127.0.0.1:2200') + def mytask(): + run("ls /simple") + env.foo = "bar" + with hide('everything'): + execute(mytask) + eq_(env.foo, "bar") + eq_(env.host_string, None) diff --git a/tests/utils.py b/tests/utils.py index 5e4b4c0911..57c4be2ae3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,10 +45,13 @@ def setup(self): # Temporary local file dir self.tmpdir = tempfile.mkdtemp() + def set_network(self): + env.update(to_dict('%s@%s:%s' % (USER, HOST, PORT))) + def env_setup(self): # Set up default networking for test server env.disable_known_hosts = True - env.update(to_dict('%s@%s:%s' % (USER, HOST, PORT))) + self.set_network() env.password = PASSWORDS[USER] # Command response mocking is easier without having to account for # shell wrapping everywhere.