Skip to content

Commit

Permalink
Merge pull request #1082 from azavea/lf/backport-predict
Browse files Browse the repository at this point in the history
[BACKPORT] Add support for vector output to predict command
  • Loading branch information
lewfish committed Jan 27, 2021
2 parents c63b2ce + 9555799 commit e0ad52a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 14 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
CHANGELOG
=========

Raster Vision 0.12.1
---------------------

Bug Fixes
^^^^^^^^^^^

* Add support for vector output to predict command `#980 <https://github.com/azavea/raster-vision/pull/980>`_ This also enables use of model bundles in the Model Zoo without having access to Azavea's S3 buckets.

Raster Vision 0.12
-------------------

Expand Down
18 changes: 10 additions & 8 deletions docs/cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,17 @@ Use ``predict`` to make predictions on new imagery given a :ref:`model bundle <m
> rastervision predict --help
Usage: rastervision predict [OPTIONS] MODEL_BUNDLE IMAGE_URI OUTPUT_URI
Usage: rastervision predict [OPTIONS] MODEL_BUNDLE IMAGE_URI LABEL_URI
Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store
the prediction output at OUTPUT_URI.
the prediction output at LABEL_URI.
Options:
-a, --update-stats Run an analysis on this individual image, as opposed
to using any analysis like statistics that exist in
the prediction package
--channel-order TEXT List of indices comprising channel_order. Example: 2 1
0
--help Show this message and exit.
--vector-label-uri TEXT URI to save vectorized labels for semantic
segmentation model bundles that support it
-a, --update-stats Run an analysis on this individual image, as
opposed to using any analysis like statistics that
exist in the prediction package
--channel-order TEXT List of indices comprising channel_order. Example:
2 1 0
--help Show this message and exit.
16 changes: 12 additions & 4 deletions rastervision_core/rastervision/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ def parser_process(value, state):
'predict', short_help='Use a model bundle to predict on new images.')
@click.argument('model_bundle')
@click.argument('image_uri')
@click.argument('output_uri')
@click.argument('label_uri')
@click.option(
'--vector-label-uri',
type=str,
help=
('URI to save vectorized labels for semantic segmentation model bundles that support '
'it'))
@click.option(
'--update-stats',
'-a',
Expand All @@ -51,9 +57,10 @@ def parser_process(value, state):
'--channel-order',
cls=OptionEatAll,
help='List of indices comprising channel_order. Example: 2 1 0')
def predict(model_bundle, image_uri, output_uri, update_stats, channel_order):
def predict(model_bundle, image_uri, label_uri, vector_label_uri, update_stats,
channel_order):
"""Make predictions on the images at IMAGE_URI
using MODEL_BUNDLE and store the prediction output at OUTPUT_URI.
using MODEL_BUNDLE and store the prediction output at LABEL_URI.
"""
if channel_order is not None:
channel_order = [
Expand All @@ -63,4 +70,5 @@ def predict(model_bundle, image_uri, output_uri, update_stats, channel_order):
with rv_config.get_tmp_dir() as tmp_dir:
predictor = Predictor(model_bundle, tmp_dir, update_stats,
channel_order)
predictor.predict([image_uri], output_uri)
predictor.predict(
[image_uri], label_uri, vector_label_uri=vector_label_uri)
20 changes: 18 additions & 2 deletions rastervision_core/rastervision/core/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,34 @@ def __init__(self,
if channel_order is not None:
self.scene.raster_source.channel_order = channel_order

def predict(self, image_uris, label_uri):
def predict(self, image_uris, label_uri, vector_label_uri=None):
"""Generate predictions for the given image.
Args:
image_uris: URIs of the images to make predictions against.
This can be any type of URI readable by Raster Vision
FileSystems.
label_uri: URI to save labels off into.
label_uri: URI to save labels off into
vector_label_uri: URI to save vectorized labels for semantic segmentation
model bundles that support it
"""
try:
self.scene.raster_source.uris = image_uris
self.scene.label_store.uri = label_uri
if (hasattr(self.scene.label_store, 'vector_output')
and self.scene.label_store.vector_output):
if vector_label_uri:
for vo in self.scene.label_store.vector_output:
vo.uri = join(
vector_label_uri, '{}-{}.json'.format(
vo.class_id, vo.get_mode()))
else:
self.scene.label_store.vector_output = []
elif vector_label_uri:
log.warn(
'vector_label_uri was supplied but this model bundle does not '
'generate vector labels.')

if self.update_stats:
self.pipeline.analyze()
self.pipeline.predict()
Expand Down

0 comments on commit e0ad52a

Please sign in to comment.