Overhaul decorators to use func attrs vs registries
bitprophet committed Sep 20, 2011
1 parent 28bcb92 commit 4f68921
Showing 5 changed files with 80 additions and 73 deletions.
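In short: before this commit, @serial and @parallel recorded task names in module-level sets (_serial, _parallel) that helpers such as is_serial, is_parallel and needs_multiprocessing then consulted; after it, the decorators set flags directly on the decorated function (serial, parallel, pool_size) and callers read them back with getattr. A minimal sketch of the two styles (illustrative names, not the exact Fabric code):

    # Before: a module-level registry keyed on the function's name (Python 2).
    _parallel = set()

    def parallel_via_registry(func):
        _parallel.add(func.func_name)
        return func

    def is_parallel_via_registry(func):
        return func.func_name in _parallel

    # After: the decorator tags the function object itself.
    def parallel_via_attrs(func):
        func.parallel = True
        func.serial = False
        return func

    def is_parallel_via_attrs(func):
        return getattr(func, 'parallel', False)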
35 changes: 12 additions & 23 deletions fabric/decorators.py
@@ -136,32 +136,29 @@ def decorated(*args, **kwargs):
return serial(decorated)


_serial = set()
def serial(func):
"""
Forces the wrapped function to always run sequentially, never in parallel.
This decorator takes precedence over the global value of
:ref:`env.parallel <env-parallel>`.
This decorator takes precedence over the global value of :ref:`env.parallel
<env-parallel>`. However, if a task is decorated with both
`~fabric.decorators.serial` *and* `~fabric.decorators.parallel`,
`~fabric.decorators.parallel` wins.
.. versionadded:: 1.3
"""
# Register
_serial.add(func.func_name)
_parallel.discard(func.func_name)
if not getattr(func, 'parallel', False):
func.serial = True
return func

def is_serial(func):
return func.func_name in _serial


_parallel = set()
def parallel(pool_size=None):
"""
Forces the wrapped function to run in parallel, instead of sequentially.
This decorator takes precedence over the global value of
:ref:`env.parallel <env-parallel>`.
This decorator takes precedence over the global value of :ref:`env.parallel
<env-parallel>`. It also takes precedence over `~fabric.decorators.serial`
if a task is decorated with both.
.. versionadded:: 1.3
"""
@@ -171,11 +168,9 @@ def inner(*args, **kwargs):
# Required for Paramiko/PyCrypto to be happy in multiprocessing
Random.atfork()
return func(*args, **kwargs)
# Register
_parallel.add(func.func_name)
_serial.discard(func.func_name)
# Tell function what its pool size is
inner._pool_size = pool_size
inner.parallel = True
inner.serial = False
inner.pool_size = pool_size
return inner

# Allow non-factory-style decorator use (@decorator vs @decorator())
@@ -184,12 +179,6 @@ def inner(*args, **kwargs):

return real_decorator

def is_parallel(func):
return func.func_name in _parallel

def needs_multiprocessing():
return _parallel != set()


def with_settings(**kw_settings):
"""
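The net effect on task functions, sketched with made-up task names (fabric.decorators is the real module; the tasks are illustrative):

    from fabric.decorators import parallel, serial

    @parallel(pool_size=20)
    def deploy():
        pass

    @serial
    def migrate():
        pass

    # With this commit the markers live on the functions themselves:
    #   deploy.parallel  -> True
    #   deploy.serial    -> False
    #   deploy.pool_size -> 20
    #   migrate.serial   -> True   (no registry lookup involved)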
25 changes: 14 additions & 11 deletions fabric/main.py
@@ -22,7 +22,6 @@
from fabric.state import commands, connections, env_options
from fabric.tasks import Task
from fabric.utils import abort, indent
from fabric.decorators import is_parallel, is_serial, needs_multiprocessing
from job_queue import JobQueue

# One-time calculation of "all internal callables" to avoid doing this on every
@@ -625,21 +624,17 @@ def update_output_levels(show, hide):

def requires_parallel(task):
"""
Returns True if given ``task`` can and should be run in parallel mode.
Returns True if given ``task`` should be run in parallel mode.
Specifically:
* The ``multiprocessing`` module is loaded, and:
* It's been explicitly marked with ``@parallel``, or:
* It's *not* been explicitly marked with ``@serial`` *and* the global
parallel option (``env.parallel``) is set to ``True``.
"""
return (
('multiprocessing' in sys.modules)
and (
(state.env.parallel and not is_serial(task))
or is_parallel(task)
)
(state.env.parallel and not getattr(task, 'serial', False))
or getattr(task, 'parallel', False)
)


@@ -656,7 +651,7 @@ def _get_pool_size(task, hosts):
# change)
default_pool_size = state.env.pool_size or len(hosts)
# Allow per-task override
pool_size = getattr(task, '_pool_size', default_pool_size)
pool_size = getattr(task, 'pool_size', default_pool_size)
# But ensure it's never larger than the number of hosts
pool_size = min((pool_size, len(hosts)))
# Inform user of final pool size for this task
@@ -666,6 +661,13 @@ def _get_pool_size(task, hosts):
return pool_size


def _parallel_tasks(commands_to_run):
return any(map(
lambda x: requires_parallel(crawl(x[0], state.commands)),
commands_to_run
))


def main():
"""
Main command-line execution loop.
@@ -789,7 +791,7 @@ def main():
print("Commands to run: %s" % names)

# Import multiprocessing if needed, erroring out usefully if it can't.
if state.env.parallel or needs_multiprocessing():
if state.env.parallel or _parallel_tasks(commands_to_run):
try:
import multiprocessing
except ImportError, e:
@@ -826,7 +828,8 @@ def main():
print("[%s] Executing task '%s'" % (host, name))

# Handle parallel execution
if requires_parallel(task):
have_multiprocessing = 'multiprocessing' in sys.modules
if requires_parallel(task) and have_multiprocessing:
# Grab appropriate callable (func or instance method)
to_call = task
if hasattr(task, 'run') and callable(task.run):
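The decision logic after the change, restated as a standalone sketch (a simplification of the code above, where env stands in for fabric.state.env):

    def should_run_parallel(task, env):
        # A per-task @serial opts out of a global env.parallel=True,
        # while a per-task @parallel opts in regardless of the global flag.
        return ((env.parallel and not getattr(task, 'serial', False))
                or getattr(task, 'parallel', False))

    def resolve_pool_size(task, hosts, env):
        default = env.pool_size or len(hosts)        # global default
        size = getattr(task, 'pool_size', default)   # per-task override
        return min(size, len(hosts))                 # never exceeds host count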
74 changes: 44 additions & 30 deletions tests/test_decorators.py
@@ -1,11 +1,15 @@
from __future__ import with_statement

import random

from nose.tools import eq_, ok_, assert_true, assert_false, assert_equal
import fudge
from fudge import Fake, with_fakes
from fudge import Fake, with_fakes, patched_context

from fabric import decorators, tasks
from fabric.state import env
import fabric # for patching fabric.state.xxx
from fabric.main import _parallel_tasks, requires_parallel


#
@@ -138,7 +142,6 @@ def single_run():
pass

def test_runs_once():
assert_true(decorators.is_serial(single_run))
assert_false(hasattr(single_run, 'return_value'))
single_run()
assert_true(hasattr(single_run, 'return_value'))
@@ -150,6 +153,7 @@ def test_runs_once():
# @serial / @parallel
#


@decorators.serial
def serial():
pass
@@ -159,43 +163,53 @@ def serial():
def serial2():
pass

def test_serial():
assert_true(decorators.is_serial(serial))
assert_false(decorators.is_parallel(serial))
serial()

assert_true(decorators.is_serial(serial2))
assert_false(decorators.is_parallel(serial2))
serial2()


@decorators.parallel
def parallel():
@decorators.serial
def serial3():
pass

@decorators.parallel
@decorators.serial
def parallel2():
def parallel():
pass

@decorators.parallel(pool_size=20)
def parallel3():
def parallel2():
pass

def test_parallel():
assert_true(decorators.is_parallel(parallel))
assert_false(decorators.is_serial(parallel))
parallel()

assert_true(decorators.is_parallel(parallel2))
assert_false(decorators.is_serial(parallel2))
parallel2()

assert_true(decorators.is_parallel(parallel))
assert_false(decorators.is_serial(parallel))
assert_equal(parallel3._pool_size, 20)
assert_equal(getattr(parallel3, '_pool_size'), 20)

fake_tasks = {
'serial': serial,
'serial2': serial2,
'serial3': serial3,
'parallel': parallel,
'parallel2': parallel2,
}

def parallel_task_helper(actual_tasks, expected):
commands_to_run = map(lambda x: [x], actual_tasks)
with patched_context(fabric.state, 'commands', fake_tasks):
eq_(_parallel_tasks(commands_to_run), expected)

def test_parallel_tasks():
for desc, task_names, expected in (
("One @serial-decorated task == no parallelism",
['serial'], False),
("One @parallel-decorated task == parallelism",
['parallel'], True),
("One @parallel-decorated and one @serial-decorated task == paralellism",
['parallel', 'serial'], True),
("Tasks decorated with both @serial and @parallel count as @parallel",
['serial2', 'serial3'], True)
):
parallel_task_helper.description = desc
yield parallel_task_helper, task_names, expected
del parallel_task_helper.description

def test_parallel_wins_vs_serial():
"""
@parallel takes precedence over @serial when both are used on one task
"""
ok_(requires_parallel(serial2))
ok_(requires_parallel(serial3))


#
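The new tests lean on two patterns worth noting: fudge's patched_context temporarily swaps fabric.state.commands for a fake task mapping, and nose's generator-test style yields a helper plus its arguments, with a per-case description. A compact sketch of the combination, assuming fudge and nose are installed and fake_tasks/_parallel_tasks are as above:

    from fudge import patched_context
    import fabric.state

    def check(name, expected):
        with patched_context(fabric.state, 'commands', fake_tasks):
            assert _parallel_tasks([[name]]) == expected

    def test_parallel_detection():
        for name, expected in (('serial', False), ('parallel', True)):
            check.description = "%s -> parallel=%s" % (name, expected)
            yield check, name, expected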
11 changes: 2 additions & 9 deletions tests/test_main.py
@@ -17,7 +17,8 @@
from fabric.state import _AttributeDict
from fabric.tasks import Task

from utils import mock_streams, patched_env, eq_, FabricTest, fabfile
from utils import (mock_streams, patched_env, eq_, FabricTest, fabfile,
path_prefix)


# Stupid load_fabfile wrapper to hide newly added return value.
@@ -315,14 +316,6 @@ def test_load_fabfile_should_not_remove_real_path_elements():
# Namespacing and new-style tasks
#

@contextmanager
def path_prefix(module):
i = 0
sys.path.insert(i, os.path.dirname(module))
yield
sys.path.pop(i)


class TestTaskAliases(FabricTest):
def test_flat_alias(self):
f = fabfile("flat_alias.py")
8 changes: 8 additions & 0 deletions tests/utils.py
@@ -260,3 +260,11 @@ def wrapper(func):

def fabfile(name):
return os.path.join(os.path.dirname(__file__), 'support', name)


@contextmanager
def path_prefix(module):
i = 0
sys.path.insert(i, os.path.dirname(module))
yield
sys.path.pop(i)
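For reference, the relocated helper is typically used like this (the fabfile path is made up for illustration):

    # Put the support fabfile's directory at the front of sys.path only
    # for the duration of the block.
    with path_prefix('/path/to/support/deep_namespacing/fabfile.py'):
        import fabfile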
