Skip to content

Commit

Permalink
Merge pull request #464 from azavea/lf/rasterizer-config
Browse files Browse the repository at this point in the history
Add rasterizer options
  • Loading branch information
lossyrob committed Oct 7, 2018
2 parents cd3461b + a36cd61 commit 1683364
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rastervision.data.utils import geojson_to_shapes


def infer_cell(str_tree, cell, ioa_thresh, use_intersection_over_cell,
def infer_cell(shapes, cell, ioa_thresh, use_intersection_over_cell,
background_class_id, pick_min_class_id):
"""Infer the class_id of a cell given a set of polygons.
Expand All @@ -22,7 +22,7 @@ def infer_cell(str_tree, cell, ioa_thresh, use_intersection_over_cell,
considered null or background. See args for more details.
Args:
str_tree: shapely.strtree.STRtree of shapely geometry with class_id attributes
shapes: List of (shapely.geometry, class_id) tuples
cell: Box
ioa_thresh: (float) the minimum IOA of a polygon and cell for that
polygon to be a candidate for setting the class_id
Expand All @@ -37,6 +37,14 @@ def infer_cell(str_tree, cell, ioa_thresh, use_intersection_over_cell,
class_id of the boxes in that cell. Otherwise, pick the class_id of
the box covering the greatest area.
"""
str_tree = STRtree([shape for shape, class_id in shapes])
# Monkey-patching class_id onto shapely.geom is not a good idea because
# if you transform it, the class_id will be lost, but this works here. I wanted to
# use a dictionary to associate shape with class_id, but couldn't because they are
# mutable.
for shape, class_id in shapes:
shape.class_id = class_id

cell_geom = geometry.Polygon(
[(p[0], p[1]) for p in cell.geojson_coordinates()])
intersecting_polygons = str_tree.query(cell_geom)
Expand Down Expand Up @@ -92,16 +100,16 @@ def infer_labels(geojson, crs_transformer, extent, cell_size, ioa_thresh,
ChipClassificationLabels
"""
shapes = geojson_to_shapes(geojson, crs_transformer)
for shape in shapes:
# TODO: handle linestrings
for shape, class_id in shapes:
if type(shape) != geometry.Polygon:
raise ValueError(
'Chip classification can only handle geoms of type Polygon')
str_tree = STRtree(shapes)
labels = ChipClassificationLabels()

cells = extent.get_windows(cell_size, cell_size)
for cell in cells:
class_id = infer_cell(str_tree, cell, ioa_thresh,
class_id = infer_cell(shapes, cell, ioa_thresh,
use_intersection_over_cell, background_class_id,
pick_min_class_id)
labels.set_cell(cell, class_id)
Expand Down
44 changes: 27 additions & 17 deletions src/rastervision/data/raster_source/geojson_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,48 @@
from rastervision.data.utils import geojson_to_shapes


def geojson_to_raster(geojson, rasterizer_options, extent, crs_transformer):
    """Rasterize the shapes of a GeoJSON dict into a 2D array of class ids.

    Args:
        geojson: dict in GeoJSON format, converted by geojson_to_shapes into
            (shapely geometry, class_id) tuples
        rasterizer_options: object exposing line_buffer (amount by which
            LineStrings are buffered before rasterizing) and
            background_class_id (value for pixels not covered by any shape)
        extent: object providing to_shapely(), get_height() and get_width();
            shapes are cropped against it and it fixes the output size
        crs_transformer: passed through to geojson_to_shapes

    Returns:
        Array of shape (extent.get_height(), extent.get_width()). When no
        shape overlaps the extent, every pixel is background_class_id.
    """
    line_buffer = rasterizer_options.line_buffer
    background_class_id = rasterizer_options.background_class_id

    # The extent geometry does not change between shapes, so build it once.
    extent_geom = extent.to_shapely()

    # Crop shapes against extent, drop the ones that end up empty, and turn
    # the remaining lines into polygons so they have area to rasterize.
    shapes = []
    for shape, class_id in geojson_to_shapes(geojson, crs_transformer):
        shape = shape.intersection(extent_geom)
        if shape.is_empty:
            continue
        if type(shape) is shapely.geometry.LineString:
            shape = shape.buffer(line_buffer)
        shapes.append((shape, class_id))

    out_shape = (extent.get_height(), extent.get_width())
    if shapes:
        raster = rasterize(
            shapes, out_shape=out_shape, fill=background_class_id)
    else:
        # Nothing to burn in: the whole extent is background. This used to be
        # np.ones(out_shape), which ignored the configured background id.
        raster = np.full(out_shape, background_class_id)

    return raster


class GeoJSONSource(RasterSource):
def __init__(self, uri, extent, crs_transformer):
"""A RasterSource based on the rasterization of a GeoJSON file."""

def __init__(self, uri, rasterizer_options, extent, crs_transformer):
"""Constructor.
Args:
uri: URI of GeoJSON file
rasterizer_options:
rastervision.data.raster_source.GeoJSONSourceConfig.RasterizerOptions
extent: (Box) extent of corresponding imagery RasterSource
crs_transformer: (CRSTransformer)
"""
self.uri = uri
self.rasterizer_options = rasterizer_options
self.extent = extent
self.crs_transformer = crs_transformer
geojson_dict = json.loads(file_to_str(self.uri))
self.raster = geojson_to_raster(geojson_dict, extent, crs_transformer)
geojson = json.loads(file_to_str(self.uri))
self.raster = geojson_to_raster(geojson, rasterizer_options, extent,
crs_transformer)
# Add third singleton dim since rasters must have >=1 channel.
self.raster = np.expand_dims(self.raster, 2)
super().__init__()
Expand Down
65 changes: 58 additions & 7 deletions src/rastervision/data/raster_source/geojson_source_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,42 @@


class GeoJSONSourceConfig(RasterSourceConfig):
def __init__(self, uri, transformers=None, channel_order=None):
super().__init__(source_type=rv.GEOJSON_SOURCE)
class RasterizerOptions(object):
    """Settings that control how GeoJSON shapes are turned into a raster."""

    def __init__(self, background_class_id, line_buffer=15):
        """Store the rasterization settings.

        Args:
            background_class_id: The class_id to use for background pixels
                that don't overlap with any shapes in the GeoJSON file.
            line_buffer: Number of pixels to add to each side of line when
                rasterized.
        """
        self.line_buffer = line_buffer
        self.background_class_id = background_class_id

    def to_proto(self):
        """Return these settings as a GeoJSONFile.RasterizerOptions message."""
        fields = {
            'background_class_id': self.background_class_id,
            'line_buffer': self.line_buffer,
        }
        return RasterSourceConfigMsg.GeoJSONFile.RasterizerOptions(**fields)

def __init__(self,
uri,
rasterizer_options,
transformers=None,
channel_order=None):
super().__init__(
source_type=rv.GEOJSON_SOURCE,
transformers=transformers,
channel_order=channel_order)
self.uri = uri
self.rasterizer_options = rasterizer_options

def to_proto(self):
msg = super().to_proto()
msg.MergeFrom(
RasterSourceConfigMsg(
geojson_file=RasterSourceConfigMsg.GeoJSONFile(uri=self.uri)))
geojson_file=RasterSourceConfigMsg.GeoJSONFile(
uri=self.uri,
rasterizer_options=self.rasterizer_options.to_proto())))
return msg

def save_bundle_files(self, bundle_dir):
Expand All @@ -41,7 +68,8 @@ def create_local(self, tmp_dir):
.build()

def create_source(self, tmp_dir, extent, crs_transformer):
return GeoJSONSource(self.uri, extent, crs_transformer)
return GeoJSONSource(self.uri, self.rasterizer_options, extent,
crs_transformer)

def preprocess_command(self, command_type, experiment_config,
context=None):
Expand All @@ -56,7 +84,10 @@ class GeoJSONSourceConfigBuilder(RasterSourceConfigBuilder):
def __init__(self, prev=None):
config = {}
if prev:
config = {'uri': prev.uri}
config = {
'uri': prev.uri,
'rasterizer_options': prev.rasterizer_options
}

super().__init__(GeoJSONSourceConfig, config)

Expand All @@ -67,13 +98,33 @@ def validate(self):
'You must specify a uri for the GeoJSONSourceConfig. Use "with_uri"'
)

if self.config.get('rasterizer_options') is None:
raise rv.ConfigError(
'You must configure the rasterizer for the GeoJSONSourceConfig. '
'Use "with_rasterizer_options"')

def from_proto(self, msg):
b = super().from_proto(msg)

return b \
.with_uri(msg.geojson_file.uri)
.with_uri(msg.geojson_file.uri) \
.with_rasterizer_options(
msg.geojson_file.rasterizer_options.background_class_id,
msg.geojson_file.rasterizer_options.line_buffer)

def with_uri(self, uri):
    """Return a deep copy of this builder with config['uri'] set to uri.

    The builder the method is called on is left unmodified.
    """
    updated = deepcopy(self)
    updated.config['uri'] = uri
    return updated

def with_rasterizer_options(self, background_class_id, line_buffer=15):
    """Return a copy of this builder configured with rasterizer options.

    The options control how the GeoJSON is converted to a raster. The builder
    the method is called on is left unmodified.

    Args:
        background_class_id: The class_id to use for background pixels that
            don't overlap with any shapes in the GeoJSON file.
        line_buffer: Number of pixels to add to each side of line when
            rasterized.
    """
    options = GeoJSONSourceConfig.RasterizerOptions(
        background_class_id, line_buffer=line_buffer)
    updated = deepcopy(self)
    updated.config['rasterizer_options'] = options
    return updated
17 changes: 7 additions & 10 deletions src/rastervision/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import shapely


def geojson_to_shapes(geojson_dict, crs_transformer):
def geojson_to_shapes(geojson, crs_transformer):
"""Convert GeoJSON into list of shapely.geometry shape.
Args:
geojson_dict: dict in GeoJSON format with class_id property for each
geojson: dict in GeoJSON format with class_id property for each
feature (class_id defaults to 1 if missing)
crs_transformer: CRSTransformer used to convert from map to pixel
coords
Returns:
List of shapely.geometry with .class_id attributes
List of (shapely.geometry, class_id) tuples
"""
features = geojson_dict['features']
features = geojson['features']
shapes = []

for feature in features:
Expand All @@ -28,20 +28,17 @@ def geojson_to_shapes(geojson_dict, crs_transformer):
shape = [crs_transformer.map_to_pixel(p) for p in shell]
# Trick to handle self-intersecting polygons using buffer(0)
shape = shapely.geometry.Polygon(shape).buffer(0)
shape.class_id = class_id
shapes.append(shape)
shapes.append((shape, class_id))
elif geom_type == 'Polygon':
shell = coordinates[0]
shape = [crs_transformer.map_to_pixel(p) for p in shell]
# Trick to handle self-intersecting polygons using buffer(0)
shape = shapely.geometry.Polygon(shape).buffer(0)
shape.class_id = class_id
shapes.append(shape)
shapes.append((shape, class_id))
elif geom_type == 'LineString':
shape = [crs_transformer.map_to_pixel(p) for p in coordinates]
shape = shapely.geometry.LineString(shape)
shape.class_id = class_id
shapes.append(shape)
shapes.append((shape, class_id))
else:
# TODO: logging warning that this type can't be parsed.
pass
Expand Down
9 changes: 9 additions & 0 deletions src/rastervision/protos/raster_source.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@ message RasterSourceConfig {

// Used to read a GeoJSON file as a raster useful for semantic segmentation.
message GeoJSONFile {
// Options controlling how the GeoJSON shapes are converted to a raster.
// NOTE(review): field numbers here start at 2 and tag 1 is unused. Field
// numbers are part of the wire format, so do not renumber.
message RasterizerOptions {
// The class_id to use for background pixels that don't overlap with any
// shapes in the GeoJSON file.
required int32 background_class_id = 2;

// Number of pixels to add to each side of line when rasterized.
optional int32 line_buffer = 3 [default=15];
}
// URI of the GeoJSON file.
required string uri = 1;
// Rasterization settings. Required, so messages written without this field
// will not pass required-field validation.
required RasterizerOptions rasterizer_options = 2;
}

required string source_type = 1;
Expand Down

0 comments on commit 1683364

Please sign in to comment.