Skip to content

Commit

Permalink
consolidate prediction options into PredictOptions for all tasks (#2055)
Browse files Browse the repository at this point in the history
Changes:
- Move `predict_chip_sz` and `predict_batch_sz` to `PredictOptions`. Also add `stride`.
- Make `predict_options` a field in `RVPipelineConfig`. Previously, this was only defined in the SS and OD subclasses.
- Make `Backend.predict_scene()` take `PredictOptions` instead of `chip_sz`, `stride` etc.
- Move default `stride` and `crop_sz` initialization to pydantic validators.
- Move OD prediction post-processing to the OD PyTorch `Backend`.
- Remove unused SS post-processing.
- Remove support for `RASTERVISION_PREDICT_BATCH_SIZE` `RVConfig` param that was used by `Learner.predict_dataset()`.
- Update usage in examples.
- Update unit and integration tests.
  • Loading branch information
AdeelH committed Feb 7, 2024
1 parent 6ad3bdd commit 9fd4dd5
Show file tree
Hide file tree
Showing 32 changed files with 244 additions and 240 deletions.
8 changes: 4 additions & 4 deletions integration_tests/chip_classification/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from os.path import join, dirname, basename

from rastervision.core.rv_pipeline import (ChipClassificationConfig,
ChipOptions, WindowSamplingConfig,
WindowSamplingMethod)
from rastervision.core.rv_pipeline import (
ChipClassificationConfig, ChipOptions, PredictOptions,
WindowSamplingConfig, WindowSamplingMethod)
from rastervision.core.data import (
ClassConfig, ChipClassificationLabelSourceConfig,
GeoJSONVectorSourceConfig, RasterioSourceConfig, StatsTransformerConfig,
Expand Down Expand Up @@ -102,6 +102,6 @@ def make_scene(img_path, label_path):
dataset=scene_dataset,
backend=backend,
chip_options=chip_options,
predict_chip_sz=chip_sz)
predict_options=PredictOptions(chip_sz=chip_sz))

return config
3 changes: 1 addition & 2 deletions integration_tests/object_detection/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,11 @@ def make_scene(scene_id, img_path, label_path):
run_tensorboard=False)

predict_options = ObjectDetectionPredictOptions(
merge_thresh=0.1, score_thresh=0.5)
chip_sz=chip_sz, merge_thresh=0.1, score_thresh=0.5)

return ObjectDetectionConfig(
root_uri=root_uri,
dataset=scene_dataset,
backend=backend,
predict_chip_sz=chip_sz,
chip_options=chip_options,
predict_options=predict_options)
6 changes: 4 additions & 2 deletions integration_tests/semantic_segmentation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
RGBClassTransformerConfig)
from rastervision.core.rv_pipeline import (
SemanticSegmentationChipOptions, SemanticSegmentationConfig,
WindowSamplingConfig, WindowSamplingMethod)
SemanticSegmentationPredictOptions, WindowSamplingConfig,
WindowSamplingMethod)
from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
from rastervision.pytorch_learner import (
Backbone, SolverConfig, SemanticSegmentationModelConfig,
Expand Down Expand Up @@ -101,10 +102,11 @@ def make_scene(id, img_path, label_path):
solver=solver,
log_tensorboard=False,
run_tensorboard=False)
predict_options = SemanticSegmentationPredictOptions(chip_sz=chip_sz)

return SemanticSegmentationConfig(
root_uri=root_uri,
dataset=scene_dataset,
backend=backend,
chip_options=chip_options,
predict_chip_sz=chip_sz)
predict_options=predict_options)
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def register_plugin(registry):
registry.set_plugin_version('rastervision.core', 11)
registry.set_plugin_version('rastervision.core', 12)
from rastervision.core.cli import predict, predict_scene
registry.add_plugin_command(predict)
registry.add_plugin_command(predict_scene)
Expand Down
9 changes: 5 additions & 4 deletions rastervision_core/rastervision/core/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
if TYPE_CHECKING:
from rastervision.core.data_sample import DataSample
from rastervision.core.data import DatasetConfig, Labels, Scene
from rastervision.core.rv_pipeline import ChipOptions
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions


class SampleWriter(AbstractContextManager):
Expand Down Expand Up @@ -48,12 +48,13 @@ def load_model(self, uri: Optional[str] = None):
"""

@abstractmethod
def predict_scene(self, scene: 'Scene', chip_sz: int,
stride: int) -> 'Labels':
def predict_scene(self, scene: 'Scene',
predict_options: 'PredictOptions') -> 'Labels':
"""Return predictions for an entire scene using the model.
Args:
scene (Scene): Scene to run inference on.
scene: Scene to run inference on.
predict_options: Prediction options.
Return:
Labels object containing predictions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@
ChipOptions.__name__,
WindowSamplingConfig.__name__,
WindowSamplingMethod.__name__,
PredictOptions.__name__,
]
Original file line number Diff line number Diff line change
@@ -1,33 +1,5 @@
from typing import TYPE_CHECKING
import logging

from rastervision.core.rv_pipeline import RVPipeline
from rastervision.core.data.label import ObjectDetectionLabels

if TYPE_CHECKING:
from rastervision.core.data import Labels, Scene

log = logging.getLogger(__name__)


class ObjectDetection(RVPipeline):
def predict_scene(self, scene: 'Scene') -> 'Labels':
if self.backend is None:
self.build_backend()

# Use strided windowing to ensure that each object is fully visible (ie. not
# cut off) within some window. This means prediction takes 4x longer for object
# detection :(
chip_sz = self.config.predict_chip_sz
stride = chip_sz // 2
labels = self.backend.predict_scene(
scene, chip_sz=chip_sz, stride=stride)
labels = self.post_process_predictions(labels, scene)
return labels

def post_process_predictions(self, labels: ObjectDetectionLabels,
scene: 'Scene') -> ObjectDetectionLabels:
return ObjectDetectionLabels.prune_duplicates(
labels,
score_thresh=self.config.predict_options.score_thresh,
merge_thresh=self.config.predict_options.merge_thresh)
pass
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from rastervision.pipeline.config import register_config, Field
from rastervision.pipeline.config import Field, register_config, validator
from rastervision.core.rv_pipeline import (
ChipOptions, RVPipelineConfig, PredictOptions, WindowSamplingConfig)
from rastervision.core.data.label_store import ObjectDetectionGeoJSONStoreConfig
Expand Down Expand Up @@ -42,6 +42,10 @@ class ObjectDetectionChipOptions(ChipOptions):

@register_config('object_detection_predict_options')
class ObjectDetectionPredictOptions(PredictOptions):
stride: Optional[int] = Field(
None,
description='Stride of the sliding window for generating chips. '
'Defaults to half of ``chip_sz``.')
merge_thresh: float = Field(
0.5,
description=
Expand All @@ -55,15 +59,20 @@ class ObjectDetectionPredictOptions(PredictOptions):
('Predicted boxes are only output if their score is above score_thresh.'
))

@validator('stride', always=True)
def validate_stride(cls, v: Optional[int],
                    values: dict) -> Optional[int]:
    """Default ``stride`` to half of ``chip_sz`` when it is not given.

    Args:
        v: The user-supplied stride, or ``None`` if it was not set.
        values: Previously validated field values, keyed by field name.

    Returns:
        ``v`` unchanged if it was provided, otherwise ``chip_sz // 2``.
        ``None`` is returned if ``chip_sz`` is not available in ``values``.
    """
    if v is not None:
        return v
    # pydantic leaves a field out of ``values`` when that field failed its
    # own validation. Use .get() so the original chip_sz error is reported
    # instead of being masked by a KeyError raised here.
    chip_sz: Optional[int] = values.get('chip_sz')
    if chip_sz is None:
        return None
    return chip_sz // 2


@register_config('object_detection')
class ObjectDetectionConfig(RVPipelineConfig):
"""Configure an :class:`.ObjectDetection` pipeline."""

chip_options: Optional[ObjectDetectionChipOptions] = Field(
None, description='Config for chip stage.')
predict_options: Optional[
ObjectDetectionPredictOptions] = ObjectDetectionPredictOptions()
chip_options: Optional[ObjectDetectionChipOptions]
predict_options: Optional[ObjectDetectionPredictOptions]

def build(self, tmp_dir):
from rastervision.core.rv_pipeline.object_detection import ObjectDetection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,8 @@ def predict(self, split_ind=0, num_splits=1):
def predict_scene(self, scene: Scene) -> Labels:
if self.backend is None:
self.build_backend()
chip_sz = self.config.predict_chip_sz
stride = chip_sz
labels = self.backend.predict_scene(
scene, chip_sz=chip_sz, stride=stride)
scene, predict_options=self.config.predict_options)
labels = self.post_process_predictions(labels, scene)
return labels

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,51 @@
from rastervision.core.backend import BackendConfig
from rastervision.core.evaluation import EvaluatorConfig
from rastervision.core.analyzer import AnalyzerConfig
from rastervision.core.rv_pipeline.chip_options import (ChipOptions,
WindowSamplingConfig)
from rastervision.pipeline.config import (Config, Field, register_config)
from rastervision.core.rv_pipeline.chip_options import ChipOptions
from rastervision.pipeline.config import (Config, Field, register_config,
validator)

if TYPE_CHECKING:
from rastervision.core.backend.backend import Backend # noqa


@register_config('predict_options')
class PredictOptions(Config):
    """Options for the predict stage: chip size, stride and batch size.

    Each field is described by its ``Field`` description. ``stride`` falls
    back to ``chip_sz`` when it is not set explicitly.
    """
    chip_sz: int = Field(
        300, description='Size of predictions chips in pixels.')
    stride: Optional[int] = Field(
        None,
        description='Stride of the sliding window for generating chips. '
        'Defaults to ``chip_sz``.')
    batch_sz: int = Field(
        8, description='Batch size to use during prediction.')

    @validator('stride', always=True)
    def validate_stride(cls, v: Optional[int],
                        values: dict) -> Optional[int]:
        """Default ``stride`` to ``chip_sz`` when it is not given.

        Args:
            v: The user-supplied stride, or ``None`` if it was not set.
            values: Previously validated field values, keyed by field name.

        Returns:
            ``v`` unchanged if it was provided, otherwise ``chip_sz``.
            ``None`` is returned if ``chip_sz`` is not available in
            ``values``.
        """
        if v is not None:
            return v
        # pydantic leaves a field out of ``values`` when that field failed
        # its own validation. Use .get() so the original chip_sz error is
        # reported instead of being masked by a KeyError raised here.
        return values.get('chip_sz')


def rv_pipeline_config_upgrader(cfg_dict: dict, version: int) -> dict:
    """Upgrade a serialized pipeline config dict by one plugin version.

    Args:
        cfg_dict: The serialized config. It is modified in place.
        version: The plugin version that ``cfg_dict`` was written with.

    Returns:
        The same ``cfg_dict`` object, upgraded if ``version`` is 10 or 11
        and otherwise unchanged.
    """
    if version == 10:
        # Version 10 -> 11: fold the top-level chip settings into
        # ``chip_options`` and nest the window settings under ``sampling``.
        train_chip_sz = cfg_dict.pop('train_chip_sz', 300)
        nodata_threshold = cfg_dict.pop('chip_nodata_threshold', 1.)
        chip_options = cfg_dict.get('chip_options')
        if chip_options is None:
            # The key may be missing or explicitly serialized as null.
            chip_options = {}
        method = chip_options.pop('method', 'sliding')
        if method != 'sliding':
            method = 'random'
        chip_options['sampling'] = dict(size=train_chip_sz, method=method)
        chip_options['nodata_threshold'] = nodata_threshold
        cfg_dict['chip_options'] = chip_options
    elif version == 11:
        # Version 11 -> 12: move the top-level predict settings into
        # ``predict_options``, keeping any options that are already there.
        predict_chip_sz = cfg_dict.pop('predict_chip_sz', 300)
        predict_batch_sz = cfg_dict.pop('predict_batch_sz', 8)
        predict_options = cfg_dict.get('predict_options')
        if predict_options is None:
            # The key may be missing or explicitly serialized as null.
            predict_options = {}
        predict_options['chip_sz'] = predict_chip_sz
        predict_options['batch_sz'] = predict_batch_sz
        cfg_dict['predict_options'] = predict_options
    return cfg_dict


Expand All @@ -57,13 +77,6 @@ class RVPipelineConfig(PipelineConfig):
('Analyzers to run during analyzer command. A StatsAnalyzer will be added '
'automatically if any scenes have a RasterTransformer.'))

chip_options: Optional[ChipOptions] = Field(
None, description='Config for chip stage.')
predict_chip_sz: int = Field(
300, description='Size of predictions chips in pixels.')
predict_batch_sz: int = Field(
8, description='Batch size to use during prediction.')

analyze_uri: Optional[str] = Field(
None,
description=
Expand Down Expand Up @@ -91,6 +104,11 @@ class RVPipelineConfig(PipelineConfig):
description='If provided, the model will be loaded from this bundle '
'for the train stage. Useful for fine-tuning.')

chip_options: Optional[ChipOptions] = Field(
None, description='Config for chip stage.')
predict_options: Optional[PredictOptions] = Field(
None, description='Config for predict stage.')

def update(self):
super().update()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,53 +1,5 @@
from typing import TYPE_CHECKING
import logging

import numpy as np

from rastervision.core.rv_pipeline import RVPipeline

if TYPE_CHECKING:
from rastervision.core.data import (
Labels,
Scene,
)
from rastervision.core.rv_pipeline.semantic_segmentation_config import (
SemanticSegmentationConfig)

log = logging.getLogger(__name__)


class SemanticSegmentation(RVPipeline):
def post_process_batch(self, windows, chips, labels):
# Fill in null class for any NODATA pixels.
null_class_id = self.config.dataset.class_config.null_class_id
for window, chip in zip(windows, chips):
nodata_mask = np.sum(chip, axis=2) == 0
labels.mask_fill(window, nodata_mask, fill_value=null_class_id)

return labels

def predict_scene(self, scene: 'Scene') -> 'Labels':
if self.backend is None:
self.build_backend()

cfg: 'SemanticSegmentationConfig' = self.config
chip_sz = cfg.predict_chip_sz
stride = cfg.predict_options.stride
crop_sz = cfg.predict_options.crop_sz

if stride is None:
stride = chip_sz

if crop_sz == 'auto':
overlap_sz = chip_sz - stride
if overlap_sz % 2 == 1:
log.warning(
'Using crop_sz="auto" but overlap size (chip_sz minus '
'stride) is odd. This means that one pixel row/col will '
'still overlap after cropping.')
crop_sz = overlap_sz // 2

labels = self.backend.predict_scene(
scene, chip_sz=chip_sz, stride=stride, crop_sz=crop_sz)
labels = self.post_process_predictions(labels, scene)
return labels
pass

0 comments on commit 9fd4dd5

Please sign in to comment.