Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation: Chips and Training #337

Merged
merged 23 commits into from Aug 14, 2018
Merged
2 changes: 1 addition & 1 deletion scripts/run
Expand Up @@ -35,7 +35,7 @@ do
key="$1"
case $key in
--aws)
AWS="-e AWS_PROFILE=${AWS_PROFILE} -v ${HOME}/.aws:/root/.aws:ro"
AWS="-e AWS_PROFILE=${AWS_PROFILE:-default} -v ${HOME}/.aws:/root/.aws:ro"
shift # past argument
;;
--tensorboard)
Expand Down
12 changes: 12 additions & 0 deletions src/rastervision/builders/label_store_builder.py
Expand Up @@ -2,6 +2,8 @@
ObjectDetectionGeoJSONFile)
from rastervision.label_stores.classification_geojson_file import (
ClassificationGeoJSONFile)
from rastervision.label_stores.segmentation_raster_file import (
SegmentationRasterFile)


def build(config,
Expand All @@ -28,3 +30,13 @@ def build(config,
extent,
readable=readable,
writable=writable)
elif label_store_type == 'segmentation_raster_file':
return SegmentationRasterFile(
src=config.segmentation_raster_file.src,
dst=config.segmentation_raster_file.dst,
src_classes=config.segmentation_raster_file.src_classes,
dst_classes=config.segmentation_raster_file.dst_classes)
return None
else:
raise ValueError('Not sure how to generate label store for type {}'
.format(label_store_type))
29 changes: 21 additions & 8 deletions src/rastervision/builders/ml_task_builder.py
Expand Up @@ -2,7 +2,9 @@
TFObjectDetectionAPI)
from rastervision.ml_tasks.object_detection import ObjectDetection
from rastervision.ml_backends.keras_classification import KerasClassification
from rastervision.ml_backends.tf_deeplab import TFDeeplab
from rastervision.ml_tasks.classification import Classification
from rastervision.ml_tasks.semantic_segmentation import SemanticSegmentation
from rastervision.protos.machine_learning_pb2 import MachineLearning
from rastervision.core.class_map import ClassItem, ClassMap

Expand All @@ -19,19 +21,30 @@ def build(config):
object_detection_val = \
MachineLearning.Task.Value('OBJECT_DETECTION')

tf_deeplab_val = \
MachineLearning.Backend.Value('TF_DEEPLAB')
semantic_segmentation_val = \
MachineLearning.Task.Value('SEMANTIC_SEGMENTATION')

keras_classification_val = \
MachineLearning.Backend.Value('KERAS_CLASSIFICATION')
classification_val = \
MachineLearning.Task.Value('CLASSIFICATION')

if config.backend == tf_object_detection_api_val:
backend = TFObjectDetectionAPI()
elif config.backend == keras_classification_val:
backend = KerasClassification()
backend_map = {
tf_object_detection_api_val: TFObjectDetectionAPI,
tf_deeplab_val: TFDeeplab,
keras_classification_val: KerasClassification
}

task_map = {
object_detection_val: ObjectDetection,
semantic_segmentation_val: SemanticSegmentation,
classification_val: Classification
}

if config.task == object_detection_val:
task = ObjectDetection(backend, class_map)
elif config.task == classification_val:
task = Classification(backend, class_map)
# XXX backend_map and task_map may need to become a cross-product
backend = (backend_map[config.backend])()
task = (task_map[config.task])(backend, class_map)

return task
38 changes: 38 additions & 0 deletions src/rastervision/contrib/cowc/transfer_georeference.py
@@ -0,0 +1,38 @@
#!/usr/bin/env python

import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had trouble running this in the Docker container, I think due to a Python version mismatch. But, I've realized there's no need to add georeferencing to the label files. We can just use ImageFile instead of GeoTiffFiles as the RasterSource to handle non-georeferenced imagery. Here is the relevant part of the workflow config I used to get it to work:

{
    "train_scenes": [
        {
            "id": "2-10",
            "raster_source": {
                "geotiff_files": {
                    "uris": [
                        "{raw}/isprs-potsdam/4_Ortho_RGBIR/top_potsdam_2_10_RGBIR.tif"
                    ]
                }
            },
            "ground_truth_label_store": {
                "segmentation_raster_file": {
                    "src": {
                        "image_file": {
                            "uri": "{raw}/isprs-potsdam/5_Labels_for_participants_no_Boundary/top_potsdam_2_10_label_noBoundary.tif"
                        }
                    },
                    "src_classes": [ "#ffff00", "#0000ff" ],
                    "dst_classes": [ 1, 0 ]
                }
            }
        }
    ],
    "test_scenes": [
        {
            "id": "2-11",
            "raster_source": {
                "geotiff_files": {
                    "uris": [
                        "{raw}/isprs-potsdam/4_Ortho_RGBIR/top_potsdam_2_11_RGBIR.tif"
                    ]
                }
            },
            "ground_truth_label_store": {
                "segmentation_raster_file": {
                    "src": {
                        "image_file": {
                            "uri": "{raw}/isprs-potsdam/5_Labels_for_participants_no_Boundary/top_potsdam_2_11_label_noBoundary.tif"
                        }
                    },
                    "src_classes": [ "#ffff00", "#0000ff" ],
                    "dst_classes": [ 1, 0 ]
                }
            }
        }
    ],

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am running in the container, as well. I am not sure how a version mismatch could occur. Edit: I see now that you are talking about the georeferencing script. I'll address that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does image_file work when there is more than one label raster for the scene? If there is only one label raster per scene, does that imply that there is only one image raster per scene?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image_file does not work if there is more than one label raster for the scene. I'm guessing that case won't come up very frequently, but it's possible. For this dataset, it's not a problem.
Re: your second question, I think it's possible for there to be a single label raster, but multiple image rasters per scene.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is going to be a "getting started" example, we should probably use image_file and avoid having to run the georeferencing script, although it could be useful for another application.

Here's what I got when I ran the script in the container after running update:
root@e820085104e2:/opt/src# ./rastervision/contrib/cowc/transfer_georeference.py \

/opt/data/raw-data/isprs-potsdam/4_Ortho_RGBIR/top_potsdam_2_10_RGBIR.tif \
/opt/data/raw-data/isprs-potsdam/5_Labels_for_participants_no_Boundary/top_potsdam_2_10_label_noBoundary.tif \
/opt/data/raw-data/isprs-potsdam/labels/2_10.tif

Traceback (most recent call last):
File "./rastervision/contrib/cowc/transfer_georeference.py", line 27, in
ul = re.search(ul_re, ullr)
File "/usr/lib/python3.5/re.py", line 173, in search
return _compile(pattern, flags).search(string)
TypeError: cannot use a string pattern on a bytes-like object

import re
import sys

from subprocess import check_output, call

def parse_corner(gdalinfo_text, corner):
    """Extract one corner coordinate pair from ``gdalinfo`` text output.

    Args:
        gdalinfo_text: The standard output of ``gdalinfo`` as a str.
        corner: The label that starts the wanted line, e.g. 'Upper Left'.

    Returns:
        A tuple of two strings: the first two numbers found on the first
        line that starts with ``corner``. A leading minus sign is kept.

    Raises:
        ValueError: If no line starting with ``corner`` holds two numbers.

    """
    # The optional '-' keeps the sign of negative coordinates, which a
    # digits-and-dots-only pattern would silently drop.
    pattern = re.compile(
        r'^' + re.escape(corner) + r'.*?(-?[0-9\.]+).*?(-?[0-9\.]+)',
        re.MULTILINE)
    match = pattern.search(gdalinfo_text)
    if match is None:
        raise ValueError(
            'Could not find "{}" coordinates in gdalinfo output'.format(
                corner))
    return match.group(1), match.group(2)


def main(argv):
    """Copy the SRS and corner coordinates of one raster onto another.

    Runs ``gdalsrsinfo`` and ``gdalinfo`` on the RGB image, then calls
    ``gdal_translate`` to write a copy of the label image that carries the
    same ``-a_srs`` / ``-a_ullr`` values.

    Args:
        argv: Command line in ``sys.argv`` layout:
            [program, input_rgb, input_label, output_label].

    Returns:
        1 when too few arguments are given, otherwise the return code of
        the ``gdal_translate`` call.

    """
    if len(argv) < 4:
        print('Usage: {} <input_rgb.tif> <input_label.tif> <output_label.tif>'.
              format(argv[0]))
        return 1

    input_rgb, input_label, output_label = argv[1:4]

    # universal_newlines=True makes check_output return str instead of
    # bytes on Python 3; searching bytes with a str regex raises
    # "TypeError: cannot use a string pattern on a bytes-like object".
    proj4 = check_output(
        ['gdalsrsinfo', '-o', 'proj4', input_rgb], universal_newlines=True)
    # Drop surrounding whitespace and the single quotes that wrap the
    # definition, without relying on fixed character offsets.
    proj4 = proj4.strip().strip("'").strip()

    # Get upper left, lower right info
    with open(os.devnull, 'w') as devnull:
        info = check_output(
            ['gdalinfo', input_rgb], stderr=devnull, universal_newlines=True)
    ul_x, ul_y = parse_corner(info, 'Upper Left')
    lr_x, lr_y = parse_corner(info, 'Lower Right')

    args = [
        'gdal_translate', '-a_srs', proj4, '-a_ullr',
        ul_x, ul_y, lr_x, lr_y,
        input_label, output_label
    ]
    return call(args)


if __name__ == '__main__':
    sys.exit(main(sys.argv))
4 changes: 4 additions & 0 deletions src/rastervision/core/class_map.py
Expand Up @@ -43,6 +43,10 @@ def get_by_name(self, name):
return item
raise ValueError('{} is not a name in this ClassMap.'.format(name))

def get_keys(self):
    """Return the keys of ``self.class_item_map`` as a new list."""
    keys = self.class_item_map.keys()
    return list(keys)

def get_items(self):
    """Return the values of ``self.class_item_map`` (the ClassItems) as a new list."""
    items = self.class_item_map.values()
    return list(items)
Expand Down
157 changes: 157 additions & 0 deletions src/rastervision/label_stores/segmentation_raster_file.py
@@ -0,0 +1,157 @@
import numpy as np

from typing import (List, Union)
from PIL import ImageColor

from rastervision.core.box import Box
from rastervision.core.label_store import LabelStore
from rastervision.core.raster_source import RasterSource
from rastervision.builders import raster_source_builder
from rastervision.protos.raster_source_pb2 import (RasterSource as
RasterSourceProto)

RasterUnion = Union[RasterSource, RasterSourceProto, str, None]


class SegmentationRasterFile(LabelStore):
"""A label store for segmentation raster files.

"""

def __init__(self,
src: RasterUnion,
dst: RasterUnion,
src_classes: List[str] = [],
dst_classes: List[int] = []):
"""Constructor.

Args:
src: A source of raster label data (either an object that
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand the difference between src and dst now. I was confused because in the other LabelStores I was just using uri for both reading and writing (but had separate readable and writable fields). But I still think it's confusing that the term dst is being used in both dst and dst_classes, which seem like different concepts. Perhaps dst_classes should be called rv_classes since they are the class ids that RV is using internally, and they are used even when there is no dst specified.

can provide it or a path).
dst: A destination for raster label data.
src_classes: A list of integer classes found in the label
source. These are zipped with the destination list
to produce a correspondence between input classes
and output classes.
dst_classes: A list of integer classes found in the
labels that are to be produced. These labels should
match those given in the workflow configuration file
(the class map).

"""
self.set_labels(src)

# Resolve the destination. `dst` may arrive as a ready RasterSource, as a
# protobuf message to build one from, as None, or as a str.
if isinstance(dst, RasterSource):
    self.dst = dst
elif isinstance(dst, RasterSourceProto):
    self.dst = raster_source_builder.build(dst)
elif dst is None or isinstance(dst, str):
    # XXX seeing str instead of RasterSourceProto
    # Strings (empty or not) are not turned into a destination here.
    # Assigning None keeps self.dst defined on every path; this also
    # avoids comparing against '' by identity, which only works by
    # accident of string interning.
    self.dst = None
else:
    raise ValueError('Unsure how to handle dst={}'.format(type(dst)))

def color_to_integer(color: str) -> int:
    """Pack the RGB value of a PIL ImageColor string into one integer.

    Args:
        color: A string understood by PIL's ImageColor.getrgb.

    Returns:
        red * 2**16 + green * 2**8 + blue. When ImageColor.getrgb
        raises ValueError for the string, each of the three components
        is instead drawn with np.random.randint(0, 256).

    """
    try:
        components = ImageColor.getrgb(color)
    except ValueError:
        # Unparseable color: fall back to a random one (drawn in
        # red, green, blue order).
        components = tuple(np.random.randint(0, 256) for _ in range(3))

    red = components[0] * 65536
    green = components[1] * 256
    blue = components[2] * 1
    return red + green + blue

src_classes = list(map(color_to_integer, src_classes))
correspondence = dict(zip(src_classes, dst_classes))

def src_to_dst(n: int) -> int:
    """Map a source class to its destination class.

    Args:
        n: A source class, looked up as a key of the enclosing
            `correspondence` dict (packed color integers built from
            src_classes).

    Returns:
        The value stored for `n` in `correspondence`, or 0 when `n`
        is not one of its keys.

    """
    return correspondence.get(n, 0)

self.src_to_dst = np.vectorize(src_to_dst)

def clear(self):
    """Drop the label source by resetting ``self.src`` to None."""
    self.src = None

def set_labels(self, src: RasterUnion) -> None:
"""Set labels, overwriting any that existed prior to this call.

Args:
src: A source of raster label data (either an object that
can provide it or a path).

Returns:
None

"""
if isinstance(src, RasterSource):
self.src = src
elif isinstance(src, RasterSourceProto):
self.src = raster_source_builder.build(src)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like how this is smart enough to handle different types. Seems like a good API design pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

elif src is None:
self.src = None
else:
raise ValueError('Unsure how to handle src={}'.format(type(src)))

if self.src is not None:
small_box = Box(0, 0, 1, 1)
(self.channels, _, _) = self.src._get_chip(small_box).shape
else:
self.channels = 1

def get_labels(self, window: Box) -> np.ndarray:
    """Return the labels that fall inside a window.

    With a label source present, the result is whatever
    self.src._get_chip(window) returns. Without one (self.src is
    None), an all-zero array sized from the window is returned.

    Args:
        window: A Box; its ymin, xmin, ymax and xmax attributes give
            the size of the zero array in the no-source case.

    Returns:
        np.ndarray

    """
    if self.src is None:
        rows = window.ymax - window.ymin
        cols = window.xmax - window.xmin
        # NOTE(review): this layout is (rows, cols, channels), while
        # set_labels reads self.channels from axis 0 of a chip's shape --
        # confirm which axis order is intended.
        return np.zeros((rows, cols, self.channels))
    return self.src._get_chip(window)

def extend(self, labels):
    """Accept `labels` and do nothing (no-op placeholder)."""
    return None

def save(self):
    """Do nothing (no-op placeholder); nothing is written."""
    return None
57 changes: 57 additions & 0 deletions src/rastervision/label_stores/segmentation_raster_file_test.py
@@ -0,0 +1,57 @@
import unittest
import numpy as np

from rastervision.core.raster_source import RasterSource
from rastervision.core.box import Box
from rastervision.label_stores.segmentation_raster_file import (
SegmentationRasterFile)


class TestingRasterSource(RasterSource):
    """In-memory RasterSource stand-in used by the tests below.

    The backing array has shape (channels, height, width) = (3, 4, 4).

    Args:
        zeros: If True the array is all zeros; otherwise it is filled by
            np.random.rand.

    """

    def __init__(self, zeros=False):
        self.width = 4
        self.height = 4
        self.channels = 3
        shape = (self.channels, self.height, self.width)
        if zeros:
            self.data = np.zeros(shape)
        else:
            self.data = np.random.rand(*shape)

    def get_extent(self):
        """Return a Box covering the whole array."""
        return Box(0, 0, self.height, self.width)

    def _get_chip(self, window):
        """Return the slice of the backing array inside `window`.

        All channels are kept; rows span ymin:ymax and columns xmin:xmax.
        """
        ymin = window.ymin
        xmin = window.xmin
        ymax = window.ymax
        xmax = window.xmax
        return self.data[:, ymin:ymax, xmin:xmax]

    def get_chip(self, window):
        """Return the chip for `window` (same data as _get_chip)."""
        # Delegate to _get_chip; calling get_chip here would recurse
        # until RecursionError.
        return self._get_chip(window)

    def get_crs_transformer(self, window):
        """Return None; this stand-in provides no CRS transformer."""
        return None


class TestSegmentationRasterFile(unittest.TestCase):
    """Unit tests for SegmentationRasterFile.clear and set_labels."""

    def test_clear(self):
        """After clear(), labels read over the old extent sum to zero."""
        store = SegmentationRasterFile(TestingRasterSource(), None)
        # The extent must be captured first: clear() discards the source.
        full_window = store.src.get_extent()
        store.clear()
        self.assertEqual(store.get_labels(full_window).sum(), 0)

    def test_set_labels(self):
        """set_labels() swaps in a new source whose data is then served."""
        replacement = TestingRasterSource()
        store = SegmentationRasterFile(TestingRasterSource(zeros=True), None)
        store.set_labels(replacement)
        full_window = store.src.get_extent()
        expected_total = replacement._get_chip(full_window).sum()
        served_total = store.get_labels(full_window).sum()
        self.assertEqual(expected_total, served_total)


if __name__ == '__main__':
unittest.main()