Skip to content

Commit

Permalink
Merge pull request #463 from azavea/rde/feature/run-files
Browse files Browse the repository at this point in the history
Allow running an experiment from a file
  • Loading branch information
lossyrob committed Oct 7, 2018
2 parents 1683364 + 48ae0c4 commit 12550e2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/rastervision/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def main(profile, verbose):
help=('Name of an importable module to look for experiment sets '
'in. If not supplied, experiments will be loaded '
'from __main__'))
@click.option(
'--path',
'-p',
metavar='PATTERN',
help=('Path of file containing ExprimentSet to run.'))
@click.option(
'--dry-run',
'-n',
Expand Down Expand Up @@ -85,13 +90,14 @@ def main(profile, verbose):
default=False,
help=('Rerun commands, regardless if '
'their output files already exist.'))
@click.option('--tempdir', help=('Temporary directory to use for this run.'))
def run(runner, commands, experiment_module, dry_run, skip_file_check, arg,
prefix, methods, filters, rerun):
prefix, methods, path, filters, rerun, tempdir):
"""Run Raster Vision commands from experiments, using the
experiment runner named RUNNER."""
darg = dict(arg)
if 'tmp_dir' in darg:
RVConfig.set_tmp_dir(darg['tmp_dir'])

if tempdir:
RVConfig.set_tmp_dir(tempdir)

# Validate runner
valid_runners = list(
Expand All @@ -104,10 +110,9 @@ def run(runner, commands, experiment_module, dry_run, skip_file_check, arg,

runner = ExperimentRunner.get_runner(runner)

if experiment_module:
module_to_load = experiment_module
else:
module_to_load = '__main__'
if experiment_module and path:
print_error('Must specify only one of experiment_module or path')
sys.exit(1)

if not commands:
commands = rv.ALL_COMMANDS
Expand All @@ -124,7 +129,12 @@ def run(runner, commands, experiment_module, dry_run, skip_file_check, arg,
experiment_method_patterns=methods,
experiment_name_patterns=filters)
try:
experiments = loader.load_from_module(module_to_load)
if experiment_module:
experiments = loader.load_from_module(experiment_module)
elif path:
experiments = loader.load_from_file(path)
else:
experiments = loader.load_from_module('__main__')
except LoaderError as e:
print_error(str(e))
sys.exit(1)
Expand All @@ -133,6 +143,8 @@ def run(runner, commands, experiment_module, dry_run, skip_file_check, arg,
if experiment_module:
print_error(
'No experiments found in {}.'.format(experiment_module))
elif path:
print_error('No experiments found in {}.'.format(path))
else:
print_error('No experiments found.')

Expand Down
23 changes: 23 additions & 0 deletions src/rastervision/experiment/experiment_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os
from importlib import import_module
from fnmatch import fnmatchcase

Expand All @@ -22,6 +23,28 @@ def __init__(self,
self.exp_method_patterns = experiment_method_patterns
self.exp_name_patterns = experiment_name_patterns

self._top_level_dir = os.path.abspath(os.curdir)

def _get_name_from_path(self, path):
"""Gets an importable name from a path.
Note: This code is from the python unittest library
"""
if path == self._top_level_dir:
return '.'
path = os.path.splitext(os.path.normpath(path))[0]

_relpath = os.path.relpath(path, self._top_level_dir)
assert not os.path.isabs(_relpath), 'Path must be within the project'
assert not _relpath.startswith('..'), 'Path must be within the project'

name = _relpath.replace(os.path.sep, '.')
return name

def load_from_file(self, path):
name = self._get_name_from_path(path)
return self.load_from_module(name)

def load_from_module(self, name):
result = []
module = import_module(name)
Expand Down
12 changes: 12 additions & 0 deletions src/tests/experiment/test_experiment_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import os

import rastervision as rv
from rastervision.experiment import ExperimentLoader
Expand Down Expand Up @@ -90,6 +91,17 @@ def test_load_module(self):
e_names,
set(['experiment_1', 'experiment_1_yes', 'experiment_2_yes']))

def test_load_file(self):
path = os.path.abspath(__file__)
args = {'required_param': 'yes', 'dummy': 1}
loader = ExperimentLoader(experiment_args=args)
experiments = loader.load_from_file(path)
self.assertEqual(len(experiments), 3)
e_names = set(map(lambda e: e.id, experiments))
self.assertEqual(
e_names,
set(['experiment_1', 'experiment_1_yes', 'experiment_2_yes']))

def test_filter_module_by_method(self):
name = '*2'
args = {'required_param': 'x'}
Expand Down

0 comments on commit 12550e2

Please sign in to comment.