Permalink
Browse files

Overhaul decorators to use func attrs vs registries

Re #19
  • Loading branch information...
bitprophet committed Sep 20, 2011
1 parent 28bcb92 commit 4f6892117dba3df6a239925123f73d5af48c8e2d
Showing with 80 additions and 73 deletions.
  1. +12 −23 fabric/decorators.py
  2. +14 −11 fabric/main.py
  3. +44 −30 tests/test_decorators.py
  4. +2 −9 tests/test_main.py
  5. +8 −0 tests/utils.py
View
@@ -136,32 +136,29 @@ def decorated(*args, **kwargs):
return serial(decorated)
-_serial = set()
def serial(func):
"""
Forces the wrapped function to always run sequentially, never in parallel.
- This decorator takes precedence over the global value of
- :ref:`env.parallel <env-parallel>`.
+ This decorator takes precedence over the global value of :ref:`env.parallel
+ <env-parallel>`. However, if a task is decorated with both
+ `~fabric.decorators.serial` *and* `~fabric.decorators.parallel`,
+ `~fabric.decorators.parallel` wins.
.. versionadded:: 1.3
"""
- # Register
- _serial.add(func.func_name)
- _parallel.discard(func.func_name)
+ if not getattr(func, 'parallel', False):
+ func.serial = True
return func
-def is_serial(func):
- return func.func_name in _serial
-
-_parallel = set()
def parallel(pool_size=None):
"""
Forces the wrapped function to run in parallel, instead of sequentially.
- This decorator takes precedence over the global value of
- :ref:`env.parallel <env-parallel>`.
+ This decorator takes precedence over the global value of :ref:`env.parallel
+ <env-parallel>`. It also takes precedence over `~fabric.decorators.serial`
+ if a task is decorated with both.
.. versionadded:: 1.3
"""
@@ -171,11 +168,9 @@ def inner(*args, **kwargs):
# Required for Paramiko/PyCrypto to be happy in multiprocessing
Random.atfork()
return func(*args, **kwargs)
- # Register
- _parallel.add(func.func_name)
- _serial.discard(func.func_name)
- # Tell function what its pool size is
- inner._pool_size = pool_size
+ inner.parallel = True
+ inner.serial = False
+ inner.pool_size = pool_size
return inner
# Allow non-factory-style decorator use (@decorator vs @decorator())
@@ -184,12 +179,6 @@ def inner(*args, **kwargs):
return real_decorator
-def is_parallel(func):
- return func.func_name in _parallel
-
-def needs_multiprocessing():
- return _parallel != set()
-
def with_settings(**kw_settings):
"""
View
@@ -22,7 +22,6 @@
from fabric.state import commands, connections, env_options
from fabric.tasks import Task
from fabric.utils import abort, indent
-from fabric.decorators import is_parallel, is_serial, needs_multiprocessing
from job_queue import JobQueue
# One-time calculation of "all internal callables" to avoid doing this on every
@@ -625,21 +624,17 @@ def update_output_levels(show, hide):
def requires_parallel(task):
"""
- Returns True if given ``task`` can and should be run in parallel mode.
+ Returns True if given ``task`` should be run in parallel mode.
Specifically:
- * The ``multiprocessing`` module is loaded, and:
* It's been explicitly marked with ``@parallel``, or:
* It's *not* been explicitly marked with ``@serial`` *and* the global
parallel option (``env.parallel``) is set to ``True``.
"""
return (
- ('multiprocessing' in sys.modules)
- and (
- (state.env.parallel and not is_serial(task))
- or is_parallel(task)
- )
+ (state.env.parallel and not getattr(task, 'serial', False))
+ or getattr(task, 'parallel', False)
)
@@ -656,7 +651,7 @@ def _get_pool_size(task, hosts):
# change)
default_pool_size = state.env.pool_size or len(hosts)
# Allow per-task override
- pool_size = getattr(task, '_pool_size', default_pool_size)
+ pool_size = getattr(task, 'pool_size', default_pool_size)
# But ensure it's never larger than the number of hosts
pool_size = min((pool_size, len(hosts)))
# Inform user of final pool size for this task
@@ -666,6 +661,13 @@ def _get_pool_size(task, hosts):
return pool_size
+def _parallel_tasks(commands_to_run):
+ return any(map(
+ lambda x: requires_parallel(crawl(x[0], state.commands)),
+ commands_to_run
+ ))
+
+
def main():
"""
Main command-line execution loop.
@@ -789,7 +791,7 @@ def main():
print("Commands to run: %s" % names)
# Import multiprocessing if needed, erroring out usefully if it can't.
- if state.env.parallel or needs_multiprocessing():
+ if state.env.parallel or _parallel_tasks(commands_to_run):
try:
import multiprocessing
except ImportError, e:
@@ -826,7 +828,8 @@ def main():
print("[%s] Executing task '%s'" % (host, name))
# Handle parallel execution
- if requires_parallel(task):
+ have_multiprocessing = 'multiprocessing' in sys.modules
+ if requires_parallel(task) and have_multiprocessing:
# Grab appropriate callable (func or instance method)
to_call = task
if hasattr(task, 'run') and callable(task.run):
View
@@ -1,11 +1,15 @@
+from __future__ import with_statement
+
import random
from nose.tools import eq_, ok_, assert_true, assert_false, assert_equal
import fudge
-from fudge import Fake, with_fakes
+from fudge import Fake, with_fakes, patched_context
from fabric import decorators, tasks
from fabric.state import env
+import fabric # for patching fabric.state.xxx
+from fabric.main import _parallel_tasks, requires_parallel
#
@@ -138,7 +142,6 @@ def single_run():
pass
def test_runs_once():
- assert_true(decorators.is_serial(single_run))
assert_false(hasattr(single_run, 'return_value'))
single_run()
assert_true(hasattr(single_run, 'return_value'))
@@ -150,6 +153,7 @@ def test_runs_once():
# @serial / @parallel
#
+
@decorators.serial
def serial():
pass
@@ -159,43 +163,53 @@ def serial():
def serial2():
pass
-def test_serial():
- assert_true(decorators.is_serial(serial))
- assert_false(decorators.is_parallel(serial))
- serial()
-
- assert_true(decorators.is_serial(serial2))
- assert_false(decorators.is_parallel(serial2))
- serial2()
-
-
@decorators.parallel
-def parallel():
+@decorators.serial
+def serial3():
pass
@decorators.parallel
-@decorators.serial
-def parallel2():
+def parallel():
pass
@decorators.parallel(pool_size=20)
-def parallel3():
+def parallel2():
pass
-def test_parallel():
- assert_true(decorators.is_parallel(parallel))
- assert_false(decorators.is_serial(parallel))
- parallel()
-
- assert_true(decorators.is_parallel(parallel2))
- assert_false(decorators.is_serial(parallel2))
- parallel2()
-
- assert_true(decorators.is_parallel(parallel))
- assert_false(decorators.is_serial(parallel))
- assert_equal(parallel3._pool_size, 20)
- assert_equal(getattr(parallel3, '_pool_size'), 20)
-
+fake_tasks = {
+ 'serial': serial,
+ 'serial2': serial2,
+ 'serial3': serial3,
+ 'parallel': parallel,
+ 'parallel2': parallel2,
+}
+
+def parallel_task_helper(actual_tasks, expected):
+ commands_to_run = map(lambda x: [x], actual_tasks)
+ with patched_context(fabric.state, 'commands', fake_tasks):
+ eq_(_parallel_tasks(commands_to_run), expected)
+
+def test_parallel_tasks():
+ for desc, task_names, expected in (
+ ("One @serial-decorated task == no parallelism",
+ ['serial'], False),
+ ("One @parallel-decorated task == parallelism",
+ ['parallel'], True),
+ ("One @parallel-decorated and one @serial-decorated task == paralellism",
+ ['parallel', 'serial'], True),
+ ("Tasks decorated with both @serial and @parallel count as @parallel",
+ ['serial2', 'serial3'], True)
+ ):
+ parallel_task_helper.description = desc
+ yield parallel_task_helper, task_names, expected
+ del parallel_task_helper.description
+
+def test_parallel_wins_vs_serial():
+ """
+ @parallel takes precedence over @serial when both are used on one task
+ """
+ ok_(requires_parallel(serial2))
+ ok_(requires_parallel(serial3))
#
View
@@ -17,7 +17,8 @@
from fabric.state import _AttributeDict
from fabric.tasks import Task
-from utils import mock_streams, patched_env, eq_, FabricTest, fabfile
+from utils import (mock_streams, patched_env, eq_, FabricTest, fabfile,
+ path_prefix)
# Stupid load_fabfile wrapper to hide newly added return value.
@@ -315,14 +316,6 @@ def test_load_fabfile_should_not_remove_real_path_elements():
# Namespacing and new-style tasks
#
-@contextmanager
-def path_prefix(module):
- i = 0
- sys.path.insert(i, os.path.dirname(module))
- yield
- sys.path.pop(i)
-
-
class TestTaskAliases(FabricTest):
def test_flat_alias(self):
f = fabfile("flat_alias.py")
View
@@ -260,3 +260,11 @@ def wrapper(func):
def fabfile(name):
return os.path.join(os.path.dirname(__file__), 'support', name)
+
+
+@contextmanager
+def path_prefix(module):
+ i = 0
+ sys.path.insert(i, os.path.dirname(module))
+ yield
+ sys.path.pop(i)

0 comments on commit 4f68921

Please sign in to comment.