Skip to content

Commit

Permalink
create the "plugins" field for user-defined plugins
Browse files Browse the repository at this point in the history
Summary:
Extend the interfaces in plugins.py to take additional plugin paths, and plumb
this all the way up. Make sure plugins are resolved relative to the project
root (in runtime.py) and not just blindly assumed to be valid paths. This field
shouldn't be able to change the behavior of existing modules, so if the same
plugin is ever defined twice, raise an error.

The complexity of this diff raises a couple of design issues that we need to
think about:
- The plugins.py interface is getting very complicated. Should we encapsulate
  it in some kind of config object? Should it take a `Runtime`? (This would make
  it difficult to unit test.)
- Allowing users to run peru from directories other than the project root is
  creating complexity throughout the code. The kind of path massaging that we
  have to do here in runtime.py is very easy to get wrong, and very easy to
  miss in testing. We really don't want to burden every new feature with a test
  to make sure it handles funny dirs. Is there a way we could make this kind of
  testing automatic? Should we use chdir magic to make paths Just Work? (That
  would make resolving paths in peru.yaml easier, at the expense of resolving
  paths given on the command line.)

Test Plan: Lots of new unit tests and an integration test.

Reviewers: sean

Differential Revision: https://phabricator.buildinspace.com/D27
  • Loading branch information
oconnor663 committed Jul 24, 2014
1 parent 3a293d5 commit 275e8fc
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 61 deletions.
19 changes: 18 additions & 1 deletion peru/parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import os
import re
import yaml
Expand All @@ -12,6 +13,10 @@ class ParserError(PrintableError):
pass


ParseResult = collections.namedtuple(
"ParseResult", ["scope", "local_module", "plugin_paths"])


def parse_file(file_path, **local_module_kwargs):
project_root = os.path.dirname(file_path)
with open(file_path) as f:
Expand All @@ -30,10 +35,22 @@ def parse_string(yaml_str, project_root='.', **local_module_kwargs):

def _parse_toplevel(blob, **local_module_kwargs):
scope = {}
plugin_paths = _extract_plugin_paths(blob)
_extract_named_rules(blob, scope)
_extract_remote_modules(blob, scope)
local_module = _build_local_module(blob, **local_module_kwargs)
return (scope, local_module)
return ParseResult(scope, local_module, plugin_paths)


def _extract_plugin_paths(blob):
raw_value = blob.pop("plugins", [])
if isinstance(raw_value, str):
plugin_paths = (raw_value,)
elif isinstance(raw_value, list):
plugin_paths = tuple(raw_value)
else:
raise ParserError("'plugins' field must be a string or a list.")
return plugin_paths


def _build_local_module(blob, **local_module_kwargs):
Expand Down
74 changes: 52 additions & 22 deletions peru/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@


def plugin_fetch(cwd, plugins_cache_root, type, dest, plugin_fields, *,
capture_output=False, stderr_to_stdout=False):
capture_output=False, stderr_to_stdout=False,
plugin_roots=()):
cache_path = _plugin_cache_path(plugins_cache_root, type)
command = _plugin_command(type, 'fetch', plugin_fields, dest, cache_path)
command = _plugin_command(type, 'fetch', plugin_roots, plugin_fields,
dest, cache_path)

kwargs = {"stderr": subprocess.STDOUT} if stderr_to_stdout else {}
if capture_output:
Expand All @@ -28,9 +30,11 @@ def plugin_fetch(cwd, plugins_cache_root, type, dest, plugin_fields, *,
subprocess.check_call(command, cwd=cwd, **kwargs)


def plugin_get_reup_fields(cwd, plugins_cache_root, type, plugin_fields):
def plugin_get_reup_fields(cwd, plugins_cache_root, type, plugin_fields, *,
plugin_roots=()):
cache_path = _plugin_cache_path(plugins_cache_root, type)
command = _plugin_command(type, 'reup', plugin_fields, cache_path)
command = _plugin_command(type, 'reup', plugin_roots, plugin_fields,
cache_path)
output = subprocess.check_output(command, cwd=cwd).decode('utf8')
new_fields = yaml.safe_load(output) or {}
for key, val in new_fields.items():
Expand All @@ -39,8 +43,8 @@ def plugin_get_reup_fields(cwd, plugins_cache_root, type, plugin_fields):
return new_fields


def _plugin_command(type, subcommand, plugin_fields, *args):
path = _plugin_exe_path(type, subcommand)
def _plugin_command(type, subcommand, plugin_roots, plugin_fields, *args):
path = _plugin_exe_path(type, subcommand, plugin_roots)

assert os.access(path, os.X_OK), type + " plugin isn't executable."
assert "--" not in plugin_fields, "-- is not a valid field name"
Expand All @@ -61,31 +65,57 @@ def _plugin_cache_path(plugins_cache_root, type):
return plugin_cache


def _plugin_exe_path(type, subcommand):
# Scan for a corresponding script dir.
root = os.path.join(PLUGINS_DIR, type)
if not os.path.isdir(root):
raise PrintableError(
'no root directory found for plugin `{}`'.format(type))
def _plugin_exe_path(type, subcommand, plugin_roots):
root = _find_plugin(type, plugin_roots)

# Scan for files in the script dir.
# The subcommand is used as a prefix, not an exact match, to support
# extensions.
# To support certain platforms, an extension is necessary for execution as
# a subprocess.
# Most notably, this is required to support Windows.
# Scan for files in the script dir. The subcommand is used as a prefix, not
# an exact match, to support extensions. To support certain platforms, an
# extension is necessary for execution as a subprocess. Most notably, this
# is required to support Windows.
matches = [match for match in os.listdir(root) if
match.startswith(subcommand) and
os.path.isfile(os.path.join(root, match))]

# Ensure there is only one match.
# It is possible for multiple files to share the subcommand prefix.
# Ensure there is only one match. It is possible for multiple files to
# share the subcommand prefix.
if not(matches):
raise PrintableError(
raise PluginCommandMissingError(
'no candidate for command `{}`'.format(subcommand))
if len(matches) > 1:
# Barf if there is more than one candidate.
raise PrintableError(
raise MultiplePluginCommandsError(
'more than one candidate for command `{}`'.format(subcommand))

return os.path.join(root, matches[0])


def _find_plugin(type, plugin_roots):
roots = [PLUGINS_DIR] + list(plugin_roots)
options = [os.path.join(root, type) for root in roots]
matches = [option for option in options if os.path.isdir(option)]
if not matches:
raise PluginMissingError(
'No plugin "{}" found at any of the following roots:\n'
.format(type) +
'\n'.join(roots))
if len(matches) > 1:
raise MultiplePluginsError(
'Multiple plugins found of type "{}":\n'.format(type) +
'\n'.join(matches))
return matches[0]


class PluginMissingError(PrintableError):
pass


class MultiplePluginsError(PrintableError):
pass


class PluginCommandMissingError(PrintableError):
pass


class MultiplePluginCommandsError(PrintableError):
pass
5 changes: 3 additions & 2 deletions peru/remote_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def get_tree(self, runtime):
return runtime.cache.keyval[key]
with runtime.tmp_dir() as tmp_dir:
plugin_fetch(runtime.root, runtime.cache.plugins_root,
self.type, tmp_dir, self.plugin_fields)
self.type, tmp_dir, self.plugin_fields,
plugin_roots=runtime.plugin_roots)
base_tree = runtime.cache.import_tree(tmp_dir)
tree = resolver.merge_import_trees(
runtime, self.imports, base_tree)
Expand All @@ -46,7 +47,7 @@ def reup(self, runtime):
print("reup", self.name)
reup_fields = plugin_get_reup_fields(
runtime.root, runtime.cache.plugins_root, self.type,
self.plugin_fields)
self.plugin_fields, plugin_roots=runtime.plugin_roots)
for field, val in reup_fields.items():
if (field not in self.plugin_fields or
val != self.plugin_fields[field]):
Expand Down
6 changes: 5 additions & 1 deletion peru/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def __init__(self, args, env):
'PERU_DIR', os.path.join(self.root, '.peru'))
compat.makedirs(self.peru_dir)

self.scope, self.local_module = parser.parse_file(
parse_result = parser.parse_file(
self.peru_file, peru_dir=self.peru_dir)
self.scope = parse_result.scope
self.local_module = parse_result.local_module
self.plugin_roots = tuple(os.path.join(self.root, path)
for path in parse_result.plugin_paths)

cache_dir = env.get('PERU_CACHE', os.path.join(self.peru_dir, 'cache'))
self.cache = cache.Cache(cache_dir)
Expand Down
25 changes: 24 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import os
import shutil
import sys
from textwrap import dedent
import unittest
Expand All @@ -11,7 +12,8 @@

import shared

peru_bin = os.path.join(os.path.dirname(__file__), "..", "..", "peru.sh")
PERU_MODULE_ROOT = os.path.abspath(
os.path.join(os.path.dirname(peru.__file__)))


def run_peru_command(args, test_dir, peru_dir, *, env_vars=None,
Expand Down Expand Up @@ -208,6 +210,27 @@ def test_local_build(self):
"fi": "feefee",
})

def test_local_plugins(self):
cp_plugin_path = os.path.join(
PERU_MODULE_ROOT, 'resources', 'plugins', 'cp')
shutil.copytree(cp_plugin_path,
os.path.join(self.test_dir, 'myplugins', 'newfangled'))
# Grab the contents now so that we can match it later.
# TODO: Rethink how these tests are structured.
expected_content = shared.read_dir(self.test_dir)
expected_content['foo'] = 'bar'

self.write_peru_yaml('''\
imports:
foo: ./
plugins: myplugins/
newfangled module foo:
path: {}
''')
self.do_integration_test(['sync'], expected_content)

def test_alternate_cache(self):
self.write_peru_yaml("""\
cp module foo:
Expand Down
71 changes: 49 additions & 22 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,52 @@
class ParserTest(unittest.TestCase):

def test_parse_empty_file(self):
scope, local_module = parse_string('')
self.assertDictEqual(scope, {})
self.assertDictEqual(local_module.imports, {})
self.assertEqual(local_module.default_rule, None)
self.assertEqual(local_module.root, '.')
result = parse_string('')
self.assertDictEqual(result.scope, {})
self.assertDictEqual(result.local_module.imports, {})
self.assertEqual(result.local_module.default_rule, None)
self.assertEqual(result.local_module.root, '.')
self.assertTupleEqual(result.plugin_paths, ())

def test_parse_with_project_root(self):
scope, local_module = parse_string('', project_root='foo/bar')
self.assertEqual(local_module.root, 'foo/bar')
result = parse_string('', project_root='foo/bar')
self.assertEqual(result.local_module.root, 'foo/bar')

def test_parse_with_plugin_paths(self):
result = parse_string(dedent('''\
plugins: foo
'''))
self.assertTupleEqual(('foo',), result.plugin_paths)

def test_parse_with_list_of_plugin_paths(self):
result = parse_string(dedent('''\
plugins:
- foo
- bar
'''))
self.assertTupleEqual(('foo', 'bar'), result.plugin_paths)

def test_parse_with_bad_plugin_paths(self):
wrong_type = dedent('''\
plugins: 5
''')
with self.assertRaises(ParserError):
parse_string(wrong_type)
empty_val = dedent('''\
plugins:
''')
with self.assertRaises(ParserError):
parse_string(empty_val)

def test_parse_rule(self):
input = dedent("""\
rule foo:
build: echo hi
export: out/
""")
scope, local_module = parse_string(input)
self.assertIn("foo", scope)
rule = scope["foo"]
result = parse_string(input)
self.assertIn("foo", result.scope)
rule = result.scope["foo"]
self.assertIsInstance(rule, Rule)
self.assertEqual(rule.name, "foo")
self.assertEqual(rule.build_command, "echo hi")
Expand All @@ -42,9 +69,9 @@ def test_parse_module(self):
wham: bam/
thank: you/maam
""")
scope, local_module = parse_string(input)
self.assertIn("foo", scope)
module = scope["foo"]
result = parse_string(input)
self.assertIn("foo", result.scope)
module = result.scope["foo"]
self.assertIsInstance(module, RemoteModule)
self.assertEqual(module.name, "foo")
self.assertEqual(module.type, "sometype")
Expand All @@ -61,9 +88,9 @@ def test_parse_module_default_rule(self):
build: foo
export: bar
""")
scope, local_module = parse_string(input)
self.assertIn("bar", scope)
module = scope["bar"]
result = parse_string(input)
self.assertIn("bar", result.scope)
module = result.scope["bar"]
self.assertIsInstance(module, RemoteModule)
self.assertIsInstance(module.default_rule, Rule)
self.assertEqual(module.default_rule.build_command, "foo")
Expand All @@ -74,17 +101,17 @@ def test_parse_toplevel_imports(self):
imports:
foo: bar/
""")
scope, local_module = parse_string(input)
self.assertDictEqual(scope, {})
self.assertDictEqual(local_module.imports, {"foo": "bar/"})
result = parse_string(input)
self.assertDictEqual(result.scope, {})
self.assertDictEqual(result.local_module.imports, {"foo": "bar/"})

def test_parse_empty_imports(self):
input = dedent('''\
imports:
''')
scope, local_module = parse_string(input)
self.assertDictEqual(scope, {})
self.assertDictEqual(local_module.imports, {})
result = parse_string(input)
self.assertDictEqual(result.scope, {})
self.assertDictEqual(result.local_module.imports, {})

def test_bad_toplevel_field_throw(self):
with self.assertRaises(ParserError):
Expand Down
Loading

0 comments on commit 275e8fc

Please sign in to comment.