Skip to content

Commit

Permalink
Run predict command from the CLI.
Browse files Browse the repository at this point in the history
Also some bug fixes and cleanup.
  • Loading branch information
lossyrob committed Sep 17, 2018
1 parent bdf1444 commit 6d415ff
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 44 deletions.
48 changes: 43 additions & 5 deletions src/rastervision/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Raster Vision main program"""
import sys
from tempfile import TemporaryDirectory

import click

Expand All @@ -14,14 +15,13 @@ def print_error(msg):
@click.group()
@click.option(
'--profile', '-p', help='Sets the configuration profile name to use.')
@click.option('--verbose', '-v', is_flag=True)
def main(profile, verbose):
# TODO: implement verbose
def main(profile):
# Initialize configuration
if profile:
rv._registry.initialize_config(profile=profile)


@main.command()
@main.command('run', short_help='Run Raster Vision commands against Experiments.')
@click.argument('runner')
@click.argument('commands', nargs=-1)
@click.option(
Expand Down Expand Up @@ -53,8 +53,11 @@ def main(profile, verbose):
default=False,
help=('Rerun commands, regardless if '
'their output files already exist.'))
@click.option('--rv-branch')
@click.option('--rv-branch', help=('Specifies the branch of the raster vision repo '
'to use for executing commands remotely'))
def run(runner, commands, experiment_module, dry_run, arg, rerun, rv_branch):
"""Run Raster Vision commands from experiments, using the
experiment runner named RUNNER."""
# Validate runner
valid_runners = list(
map(lambda x: x.lower(), rv.ExperimentRunner.list_runners()))
Expand Down Expand Up @@ -114,6 +117,7 @@ def run(runner, commands, experiment_module, dry_run, arg, rerun, rv_branch):
'parameter list takes in a parameter with that key. '
'Multiple args can be supplied'))
def ls(experiment_module, arg):
"""Print out a list of Experiment IDs."""
if experiment_module:
module_to_load = experiment_module
else:
Expand Down Expand Up @@ -141,5 +145,39 @@ def ls(experiment_module, arg):
click.echo('{}'.format(e.id))


@main.command('predict', short_help='Make predictions using a predict package.')
@click.argument('predict_package', type=click.Path(exists=True))
@click.argument('image_uri', type=click.Path(exists=True))
@click.argument('output_uri', type=click.Path(exists=False))
@click.option('--update_stats', '-a', is_flag=True,
help=('Run an analysis on this individual image, as '
'opposed to using any analysis like statistics '
'that exist in the prediction package'))
@click.option('--channel-order',
help='String containing channel_order.' + ' Example: \"2 1 0\"')
def predict(predict_package, image_uri, output_uri, update_stats, channel_order):
"""Make predictions on the image at IMAGE_URI
using PREDICT_PACKAGE and store the
prediciton output at OUTPUT_URI.
"""
if channel_order is not None:
channel_order = [
int(channel_ind) for channel_ind in channel_order.split(' ')
]
with TemporaryDirectory() as tmp_dir:
predict = rv.Predictor(predict_package,
tmp_dir,
update_stats,
channel_order).predict
predict(image_uri, output_uri)

@main.command('run_command', short_help='Run a command from configuration file.')
@click.argument('command_config_uri', type=click.Path(exists=True))
def run_command(command_config_uri):
"""Run a command from a serialized command configuration
at COMMAND_CONFIG_URI.
"""
rv.runner.CommandRunner.run(command_config_uri)

if __name__ == '__main__':
main()
9 changes: 0 additions & 9 deletions src/rastervision/data/raster_source/geotiff_source_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ def save_bundle_files(self, bundle_dir):
.build()
return (new_config, files)

def load_bundle_files(self, bundle_dir):
new_transformers = []
for transformer in self.transformers:
new_transformer = transformer.load_bundle_files(bundle_dir)
new_transformers.append(new_transformer)
return self.to_builder() \
.with_transformers(new_transformers) \
.build()

def for_prediction(self, image_uri):
return self.to_builder() \
.with_uri(image_uri) \
Expand Down
10 changes: 0 additions & 10 deletions src/rastervision/data/raster_source/image_source_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@ def save_bundle_files(self, bundle_dir):
.build()
return (new_config, files)

def load_bundle_files(self, bundle_dir):
new_transformers = []
for transformer in self.transformers:
new_transformer, t_files = transformer.load_bundle_files(
bundle_dir)
new_transformers.append(new_transformer)
return self.to_builder() \
.with_transformers(new_transformers) \
.build()

def for_prediction(self, image_uri):
return self.to_builder() \
.with_uri(image_uri) \
Expand Down
9 changes: 9 additions & 0 deletions src/rastervision/data/raster_source/raster_source_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def save_bundle_files(self, bundle_dir):
.build()
return (new_config, files)

def load_bundle_files(self, bundle_dir):
new_transformers = []
for transformer in self.transformers:
new_transformer = transformer.load_bundle_files(bundle_dir)
new_transformers.append(new_transformer)
return self.to_builder() \
.with_transformers(new_transformers) \
.build()

@abstractmethod
def create_source(self, tmp_dir):
"""Create the Raster Source for this configuration.
Expand Down
12 changes: 6 additions & 6 deletions src/rastervision/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def __init__(self,
.with_label_store()

if channel_order:
scene_builder = scene_builder.with_channel_order(channel_order)
raster_source = scene_builder.config['raster_source'] \
.to_builder() \
.with_channel_order(channel_order) \
.build()
scene_builder = scene_builder.with_raster_source(raster_source)

self.scene_config = scene_builder.build()

Expand Down Expand Up @@ -78,7 +82,7 @@ def predict(self, image_uri, label_uri=None):
# Analyzers should overwrite files in the tmp_dir
if self.update_stats:
for analyzer in self.analyzers:
analyzer.process([scene])
analyzer.process([scene], self.tmp_dir)

# Reload scene to refresh any new analyzer config
scene = scene_config.create_scene(self.task_config, self.tmp_dir)
Expand All @@ -87,7 +91,3 @@ def predict(self, image_uri, label_uri=None):
if label_uri:
scene.prediction_label_store.save(labels)
return labels


if __name__ == '__main__':
pass
1 change: 1 addition & 0 deletions src/rastervision/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from rastervision.runner.experiment_runner import *
from rastervision.runner.local_experiment_runner import *
from rastervision.runner.aws_batch_experiment_runner import *
from rastervision.runner.command_runner import *
2 changes: 1 addition & 1 deletion src/rastervision/runner/aws_batch_experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def make_command(command_config_uri):
return 'python -m rastervision.runner.command_runner {}'.format(
return 'python -m rastervision run_command {}'.format(
command_config_uri)


Expand Down
22 changes: 9 additions & 13 deletions src/rastervision/runner/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@
from rastervision.protos.command_pb2 import CommandConfig as CommandConfigMsg


@click.command
@click.argument('command_config_uri')
def run(command_config_uri):
with TemporaryDirectory as tmp_dir:
msg = load_json_config(command_config_uri, CommandConfigMsg())
PluginRegistry.get_instance().add_plugins(msg.plugins)
command_config = rv.CommandConfig.from_proto(msg)
command = command_config.create_command(tmp_dir)
command.run(tmp_dir)


if __name__ == '__main__':
run()
class CommandRunner:
@staticmethod
def run(command_config_uri):
with TemporaryDirectory as tmp_dir:
msg = load_json_config(command_config_uri, CommandConfigMsg())
PluginRegistry.get_instance().add_plugins(msg.plugins)
command_config = rv.CommandConfig.from_proto(msg)
command = command_config.create_command(tmp_dir)
command.run(tmp_dir)

0 comments on commit 6d415ff

Please sign in to comment.