Skip to content

Commit

Permalink
Merge pull request #472 from azavea/lf/aoi
Browse files Browse the repository at this point in the history
Handle AOI polygons for segmentation
  • Loading branch information
lewfish committed Oct 9, 2018
2 parents eb90e30 + 012d402 commit fefe423
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 13 deletions.
15 changes: 10 additions & 5 deletions rastervision/data/label/semantic_segmentation_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from rastervision.core.box import Box

import numpy as np
from rasterio.features import rasterize


class SemanticSegmentationLabels(Labels):
Expand Down Expand Up @@ -35,11 +36,16 @@ def __add__(self, other):
def filter_by_aoi(self, aoi_polygons):
"""Returns a copy of these labels filtered by a given set of AOI polygons
Converts values that lie outside of aoi_polygons to 0 (ie. don't care class)
Args:
aoi_polygons - A list of AOI polygons to filter by, in pixel coordinates.
"""
# TODO implement this
pass
arr = self.to_array()
mask = rasterize(
[(p, 1) for p in aoi_polygons], out_shape=arr.shape, fill=0)
arr = arr * mask
return SemanticSegmentationLabels.from_array(arr)

def add_label_pair(self, window, label_array):
self.label_pairs.append((window, label_array))
Expand All @@ -55,10 +61,9 @@ def get_extent(self):

def to_array(self):
extent = self.get_extent()
arr = np.zeros((extent.get_height(), extent.get_width()))
arr = np.zeros((extent.ymax, extent.ymax))
for window, label_array in self.label_pairs:
arr[window.ymin:window.ymax, window.xmin:window.xmax] = \
label_array
arr[window.ymin:window.ymax, window.xmin:window.xmax] = label_array
return arr

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, class_map):
color_to_class.values()))

def color_int_to_class_fn(color: int) -> int:
# Convert unspecified colors to class 0 which is "don't care"
return color_int_to_class.get(color, 0x00)

self.transform_color_int_to_class = \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ def save(self, labels):
crs = self.crs_transformer.get_image_crs()
clipped_labels = labels.get_clipped_labels(self.extent)

band_count = 1
dtype = np.int32
if self.class_trans:
band_count = 3
else:
band_count = 1

if self.class_trans:
dtype = np.uint8
else:
dtype = np.int32

# https://github.com/mapbox/rasterio/blob/master/docs/quickstart.rst
# https://rasterio.readthedocs.io/en/latest/topics/windowed-rw.html
Expand Down
5 changes: 5 additions & 0 deletions rastervision/evaluation/semantic_segmentation_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def compute(self, ground_truth_labels, prediction_labels):
# http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html # noqa
ground_truth_labels = ground_truth_labels.to_array()
prediction_labels = prediction_labels.to_array()
# This shouldn't happen, but just in case...
if ground_truth_labels.shape != prediction_labels.shape:
raise ValueError(
'ground_truth_labels and prediction_labels need to '
'have the same shape.')

evaluation_items = []
for class_id in self.class_map.get_keys():
Expand Down
14 changes: 12 additions & 2 deletions tests/data/label/test_semantic_segmentation_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ def test_from_array(self):

def test_to_array(self):
arr = self.labels.to_array()
expected_array = np.zeros((200, 200))
np.testing.assert_array_equal(arr, expected_array)
expected_arr = np.zeros((200, 200))
np.testing.assert_array_equal(arr, expected_arr)

def test_filter_by_aoi(self):
arr = np.ones((5, 5))
labels = SemanticSegmentationLabels.from_array(arr)
aoi_polygons = [Box.make_square(0, 0, 2).to_shapely()]
labels = labels.filter_by_aoi(aoi_polygons)
arr = labels.to_array()
expected_arr = np.zeros((5, 5))
expected_arr[0:2, 0:2] = 1
np.testing.assert_array_equal(arr, expected_arr)


if __name__ == '__main__':
Expand Down

0 comments on commit fefe423

Please sign in to comment.