Segmentation: Chips and Training #337
@@ -0,0 +1,38 @@
#!/usr/bin/env python

import os
import re
import sys

from subprocess import check_output, call

if len(sys.argv) < 4:
    print('Usage: {} <input_rgb.tif> <input_label.tif> <output_label.tif>'.
          format(sys.argv[0]))
    sys.exit(1)

input_rgb = sys.argv[1]
input_label = sys.argv[2]
output_label = sys.argv[3]

# Get proj4 string (note: check_output returns bytes on Python 3)
proj4 = check_output(['gdalsrsinfo', '-o', 'proj4', input_rgb], stderr=None)
proj4 = proj4[1:-2]

# Get upper left, lower right info (the optional '-' keeps negative coordinates)
with open(os.devnull, 'w') as devnull:
    ullr = check_output(['gdalinfo', input_rgb], stderr=devnull)
ul_re = re.compile(r'^Upper Left.*?(-?[0-9\.]+).*?(-?[0-9\.]+)', re.MULTILINE)
lr_re = re.compile(r'^Lower Right.*?(-?[0-9\.]+).*?(-?[0-9\.]+)', re.MULTILINE)
ul = re.search(ul_re, ullr)
lr = re.search(lr_re, ullr)

args = [
    'gdal_translate', '-a_srs', proj4, '-a_ullr',
    ul.group(1),
    ul.group(2),
    lr.group(1),
    lr.group(2), input_label, output_label
]

call(args)
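The script above scrapes the corner coordinates out of gdalinfo's human-readable report with regular expressions. As an alternative sketch, and assuming a GDAL version whose `gdalinfo` supports the `-json` flag (2.0 or later, as far as I know), the corners could be read from the `cornerCoordinates` object instead. `corner_args` is a hypothetical helper, and the sample string is made up for illustration rather than taken from a real run:

```python
import json


def corner_args(gdalinfo_json):
    """Return [ulx, uly, lrx, lry] as strings, ready to follow '-a_ullr'.

    `gdalinfo_json` is assumed to be the text printed by
    `gdalinfo -json <file>`; only 'cornerCoordinates' is read.
    """
    info = json.loads(gdalinfo_json)
    corners = info['cornerCoordinates']
    ulx, uly = corners['upperLeft']
    lrx, lry = corners['lowerRight']
    return [str(ulx), str(uly), str(lrx), str(lry)]


# Hypothetical, heavily trimmed sample of the JSON report.
sample = ('{"cornerCoordinates": {"upperLeft": [-118.5, 34.25], '
          '"lowerRight": [-118.25, 34.0]}}')
print(corner_args(sample))
```

Parsing JSON sidesteps questions about signs, spacing, and the degree/minute/second text that also appears on the corner lines of the plain report.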
@@ -0,0 +1,157 @@
import numpy as np

from typing import (List, Union)
from PIL import ImageColor

from rastervision.core.box import Box
from rastervision.core.label_store import LabelStore
from rastervision.core.raster_source import RasterSource
from rastervision.builders import raster_source_builder
from rastervision.protos.raster_source_pb2 import (RasterSource as
                                                   RasterSourceProto)

RasterUnion = Union[RasterSource, RasterSourceProto, str, None]


class SegmentationRasterFile(LabelStore):
    """A label store for segmentation raster files.

    """

    def __init__(self,
                 src: RasterUnion,
                 dst: RasterUnion,
                 src_classes: List[str] = [],
                 dst_classes: List[int] = []):
        """Constructor.

        Args:
            src: A source of raster label data (either an object that

Review comment: I think I understand the difference between

                can provide it or a path).
            dst: A destination for raster label data.
            src_classes: A list of colors (PIL ImageColor strings)
                found in the label source. These are zipped with the
                destination list to produce a correspondence between
                input classes and output classes.
            dst_classes: A list of integer classes found in the
                labels that are to be produced. These labels should
                match those given in the workflow configuration file
                (the class map).

        """
        self.set_labels(src)

        if isinstance(dst, RasterSource):
            self.dst = dst
        elif isinstance(dst, RasterSourceProto):
            self.dst = raster_source_builder.build(dst)
        elif dst is None or dst == '':
            self.dst = None
        elif isinstance(dst, str):
            pass  # XXX seeing str instead of RasterSourceProto
        else:
            raise ValueError('Unsure how to handle dst={}'.format(type(dst)))

        def color_to_integer(color: str) -> int:
            """Given a PIL ImageColor string, return a packed integer.

            Args:
                color: A PIL ImageColor string

            Returns:
                An integer containing the packed RGB values.

            """
            try:
                triple = ImageColor.getrgb(color)
            except ValueError:
                r = np.random.randint(0, 256)
                g = np.random.randint(0, 256)
                b = np.random.randint(0, 256)
                triple = (r, g, b)

            r = triple[0] * (1 << 16)
            g = triple[1] * (1 << 8)
            b = triple[2] * (1 << 0)
            integer = r + g + b
            return integer

        src_classes = list(map(color_to_integer, src_classes))
        correspondence = dict(zip(src_classes, dst_classes))

        def src_to_dst(n: int) -> int:
            """Translate source classes to destination class.

            Args:
                n: A source class represented as a packed rgb pixel
                    (an integer in the range 0 to 2**24-1).

            Returns:
                The destination class as an integer.

            """
            if n in correspondence:
                return correspondence.get(n)
            else:
                return 0

        self.src_to_dst = np.vectorize(src_to_dst)

    def clear(self):
        """Clear all labels."""
        self.src = None

    def set_labels(self, src: RasterUnion) -> None:
        """Set labels, overwriting any that existed prior to this call.

        Args:
            src: A source of raster label data (either an object that
                can provide it or a path).

        Returns:
            None

        """
        if isinstance(src, RasterSource):
            self.src = src
        elif isinstance(src, RasterSourceProto):
            self.src = raster_source_builder.build(src)

Review comment: I like how this is smart enough to handle different types. Seems like a good API design pattern.
Reply: Thanks

        elif src is None:
            self.src = None
        else:
            raise ValueError('Unsure how to handle src={}'.format(type(src)))

        if self.src is not None:
            small_box = Box(0, 0, 1, 1)
            (self.channels, _, _) = self.src._get_chip(small_box).shape
        else:
            self.channels = 1

    def get_labels(self, window: Box) -> np.ndarray:
        """Get labels from a window.

        If self.src is not None then a label window is clipped from
        it. If self.src is None then return an appropriately shaped
        np.ndarray of zeros.

        Args:
            window: A window given as a Box object.

        Returns:
            np.ndarray

        """
        if self.src is not None:
            return self.src._get_chip(window)
        else:
            ymin = window.ymin
            xmin = window.xmin
            ymax = window.ymax
            xmax = window.xmax
            return np.zeros((ymax - ymin, xmax - xmin, self.channels))

    def extend(self, labels):
        pass

    def save(self):
        pass
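The constructor above packs each source color into a 24-bit integer and maps it to a destination class, falling back to 0 for unknown colors. Since `np.vectorize` is essentially a Python-level loop, here is a minimal sketch of the same idea applied to a whole array at once. It assumes channel-last `(H, W, 3)` uint8 label images, which may not match the layout the label store actually receives; `pack_rgb` and `colors_to_classes` are hypothetical helper names, not part of Raster Vision:

```python
import numpy as np


def pack_rgb(rgb):
    """Pack an (H, W, 3) uint8 array into an (H, W) array of 24-bit ints."""
    rgb = rgb.astype(np.uint32)  # widen first so the shifts cannot overflow
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]


def colors_to_classes(rgb, correspondence, default=0):
    """Map packed colors to class ids; unmapped colors get `default`."""
    packed = pack_rgb(rgb)
    classes = np.full(packed.shape, default, dtype=np.int64)
    for color, class_id in correspondence.items():
        classes[packed == color] = class_id
    return classes


# Hypothetical class map: red -> 1, green -> 2; everything else -> 0.
correspondence = {0xFF0000: 1, 0x00FF00: 2}
labels = np.array([[[255, 0, 0], [0, 255, 0]],
                   [[0, 0, 255], [255, 0, 0]]], dtype=np.uint8)
class_map = colors_to_classes(labels, correspondence)
print(class_map.tolist())  # [[1, 2], [0, 1]]
```

The loop runs once per class rather than once per pixel, which tends to matter for segmentation rasters with only a handful of classes.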
@@ -0,0 +1,57 @@
import unittest
import numpy as np

from rastervision.core.raster_source import RasterSource
from rastervision.core.box import Box
from rastervision.label_stores.segmentation_raster_file import (
    SegmentationRasterFile)


class TestingRasterSource(RasterSource):
    def __init__(self, zeros=False):
        self.width = 4
        self.height = 4
        self.channels = 3
        if zeros:
            self.data = np.zeros((self.channels, self.height, self.width))
        else:
            self.data = np.random.rand(self.channels, self.height, self.width)

    def get_extent(self):
        return Box(0, 0, self.height, self.width)

    def _get_chip(self, window):
        ymin = window.ymin
        xmin = window.xmin
        ymax = window.ymax
        xmax = window.xmax
        return self.data[:, ymin:ymax, xmin:xmax]

    def get_chip(self, window):
        # Delegate to _get_chip; calling get_chip here would recurse forever.
        return self._get_chip(window)

    def get_crs_transformer(self, window):
        return None


class TestSegmentationRasterFile(unittest.TestCase):
    def test_clear(self):
        label_store = SegmentationRasterFile(TestingRasterSource(), None)
        extent = label_store.src.get_extent()
        label_store.clear()
        data = label_store.get_labels(extent)
        self.assertEqual(data.sum(), 0)

    def test_set_labels(self):
        raster_source = TestingRasterSource()
        label_store = SegmentationRasterFile(
            TestingRasterSource(zeros=True), None)
        label_store.set_labels(raster_source)
        extent = label_store.src.get_extent()
        rs_data = raster_source._get_chip(extent)
        ls_data = label_store.get_labels(extent)
        self.assertEqual(rs_data.sum(), ls_data.sum())


if __name__ == '__main__':
    unittest.main()

I had trouble running this in the Docker container, I think due to a Python version mismatch. But I've realized there's no need to add georeferencing to the label files. We can just use `ImageFile` instead of `GeoTiffFiles` as the `RasterSource` to handle non-georeferenced imagery. Here is the relevant part of the workflow config I used to get it to work:

I am running in the container as well. I am not sure how a version mismatch could occur.

Edit: I see now that you are talking about the georeferencing script. I'll address that.

Does `image_file` work when there is more than one label raster for the scene? If there is only one label raster per scene, does that imply that there is only one image raster per scene?

`image_file` does not work if there is more than one label raster for the scene. I'm guessing that case won't come up very frequently, but it's possible. For this dataset, it's not a problem.

Re: your second question, I think it's possible for there to be a single label raster, but multiple image rasters per scene.

Since this is going to be a "getting started" example, we should probably use `image_file` and avoid having to run the georeferencing script, although it could be useful for another application.

Here's what I got when I ran the script in the container after running `update`:

    root@e820085104e2:/opt/src# ./rastervision/contrib/cowc/transfer_georeference.py \
    Traceback (most recent call last):
      File "./rastervision/contrib/cowc/transfer_georeference.py", line 27, in <module>
        ul = re.search(ul_re, ullr)
      File "/usr/lib/python3.5/re.py", line 173, in search
        return _compile(pattern, flags).search(string)
    TypeError: cannot use a string pattern on a bytes-like object
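
The `TypeError` comes from searching `bytes` (which is what `check_output` returns on Python 3) with a pattern compiled from a `str`. One possible fix, sketched here under the assumption that the gdalinfo report is UTF-8 text, is to decode before matching. `upper_left` and the sample report are hypothetical and only illustrate the idea:

```python
import re

# Matches e.g. "Upper Left  (  -118.500,    34.250) (...)".
UL_RE = re.compile(r'^Upper Left\s*\(\s*(-?[0-9.]+),\s*(-?[0-9.]+)\)',
                   re.MULTILINE)


def upper_left(raw):
    """Return the (x, y) strings of the 'Upper Left' corner.

    Accepts either bytes (Python 3 check_output) or str.
    """
    text = raw.decode('utf-8') if isinstance(raw, bytes) else raw
    match = UL_RE.search(text)
    if match is None:
        raise ValueError('no "Upper Left" line found')
    return match.group(1), match.group(2)


# Hypothetical excerpt standing in for real gdalinfo output.
sample = (b'Corner Coordinates:\n'
          b'Upper Left  (  -118.500,    34.250)\n'
          b'Lower Right (  -118.250,    34.000)\n')
print(upper_left(sample))
```

Passing `universal_newlines=True` to `check_output` should have a similar effect, since the call then returns `str` instead of `bytes`.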