From 5617426997ad018b06327691e01f8a55aafdc038 Mon Sep 17 00:00:00 2001 From: James McClain Date: Mon, 6 Aug 2018 11:45:59 -0400 Subject: [PATCH] Class Map Colors --- .../label_stores/segmentation_raster_file.py | 39 +++++++-- src/rastervision/ml_backends/tf_deeplab.py | 79 +++++++++++++++---- .../ml_tasks/semantic_segmentation.py | 34 +++++--- src/rastervision/protos/label_store.proto | 2 +- .../segmentation/deeplab-test-remote.json | 6 +- .../segmentation/deeplab-test.json | 9 ++- 6 files changed, 131 insertions(+), 38 deletions(-) diff --git a/src/rastervision/label_stores/segmentation_raster_file.py b/src/rastervision/label_stores/segmentation_raster_file.py index 967aed1d87..76c94ba10b 100644 --- a/src/rastervision/label_stores/segmentation_raster_file.py +++ b/src/rastervision/label_stores/segmentation_raster_file.py @@ -1,6 +1,7 @@ import numpy as np from typing import (List, Union) +from PIL import ImageColor from rastervision.core.box import Box from rastervision.core.label_store import LabelStore @@ -20,7 +21,7 @@ class SegmentationRasterFile(LabelStore): def __init__(self, src: RasterUnion, dst: RasterUnion, - src_classes: List[int] = [], + src_classes: List[str] = [], dst_classes: List[int] = []): """Constructor. @@ -51,16 +52,42 @@ def __init__(self, else: raise ValueError('Unsure how to handle dst={}'.format(type(dst))) + # XXX move to utils + def color_to_integer(color: str) -> int: + try: + triple = ImageColor.getrgb(color) + except ValueError: + r = np.random.randint(0, 256) + g = np.random.randint(0, 256) + b = np.random.randint(0, 256) + triple = (r, g, b) + + r = triple[0] * (1 << 16) + g = triple[1] * (1 << 8) + b = triple[2] * (1 << 0) + integer = r + g + b + return integer + + src_classes = list(map(color_to_integer, src_classes)) correspondence = dict(zip(src_classes, dst_classes)) - def fn(n): + def src_to_dst(n: int) -> int: + """Translate source classes to destination class. + + args: + n: A source class represented as a packed rgb pixel + (an integer in the range 0 to 2**24-1). + + Returns: + The destination class as an integer. + + """ if n in correspondence: return correspondence.get(n) else: - return n + return 0 - self.fn = np.vectorize(fn) - self.correspondence = correspondence + self.src_to_dst = np.vectorize(src_to_dst) def clear(self): """Clear all labels.""" @@ -102,7 +129,7 @@ def get_labels(self, window: Union[Box, None] = None) -> np.ndarray: returned for the entire scene. Returns: - numpy.ndarray + np.ndarray """ if self.src is not None: diff --git a/src/rastervision/ml_backends/tf_deeplab.py b/src/rastervision/ml_backends/tf_deeplab.py index 4ec7636a83..91a6bb02e1 100644 --- a/src/rastervision/ml_backends/tf_deeplab.py +++ b/src/rastervision/ml_backends/tf_deeplab.py @@ -9,7 +9,7 @@ import uuid from os.path import join -from PIL import Image +from PIL import Image, ImageColor from subprocess import Popen from tensorflow.core.example.example_pb2 import Example from typing import (List, Tuple) @@ -107,7 +107,31 @@ def merge_tf_records(output_path: str, src_records: List[str]) -> None: print() -def make_debug_images(record_path: str, output_dir: str, +def string_to_triple(color: str) -> np.ndarray: + """Turn a PIL colorstring into an RGB triple. + + Args: + color: A PIL color string + + Returns: + An np.ndarray of shape (1,1,3) whether derived from the + string or chosen randomly. + + """ + try: + (r, g, b) = ImageColor.getrgb(color) + except AttributeError: + r = np.random.randint(0, 256) + g = np.random.randint(0, 256) + b = np.random.randint(0, 256) + (r, g, b) + + return np.array([r, g, b], dtype=np.uint16) + + +def make_debug_images(record_path: str, + output_dir: str, + class_map: ClassMap, p: float = 0.25) -> None: """Render a random sample of the TFRecords in a given file as human-viewable PNG files. @@ -117,24 +141,48 @@ def make_debug_images(record_path: str, output_dir: str, output_dir: Destination directory for the generated PNG files. p: The probability of rendering a particular record. - Return: + Returns: None """ make_dir(output_dir, check_empty=True) + def composite(arr: np.ndarray, *args) -> np.ndarray: + """Composite the image with the labels. + + args: + arr: An np.ndarray of shape (4,) where the first three + entries contains visual data and the fourth contains + a label datum. + *args: Ignored + + Returns: + An np.ndarray of shape (1,1,3) where the label datum has + been composited into the visual data using color + information from the color_map variable which has been + captured from the environment. + + """ + label = arr[3] + if label == 0: + return arr[0:3] + else: + color = class_map.get_by_id(label).color + label_rgb = string_to_triple(color) + image_rgb = np.array(arr[0:3], dtype=np.uint16) + return np.array((label_rgb + image_rgb) / 2, dtype=np.uint8) + print('Generating debug chips', end='', flush=True) tfrecord_iter = tf.python_io.tf_record_iterator(record_path) for ind, example in enumerate(tfrecord_iter): - example = tf.train.Example.FromString(example) - im, labels = parse_tf_example(example) - output_path = join(output_dir, '{}.png'.format(ind)) - inv_labels = (labels == 0) - im[:, :, 0] = im[:, :, 0] * inv_labels - im[:, :, 1] = im[:, :, 1] * inv_labels - im[:, :, 2] = im[:, :, 2] * inv_labels if np.random.rand() <= p: - save_img(im, output_path) + example = tf.train.Example.FromString(example) + im, labels = parse_tf_example(example) + labels3 = labels[:, :, np.newaxis] + im_labels = np.concatenate([im, labels3], axis=2) + output_path = join(output_dir, '{}.png'.format(ind)) + composited = np.apply_along_axis(composite, 2, im_labels) + save_img(composited, output_path) print('.', end='', flush=True) print() @@ -145,8 +193,9 @@ def parse_tf_example(example: Example) -> Tuple[np.ndarray, np.ndarray]: Args: example: A TensorFlow Example object. - Return: + Returns: tuple(np.ndarray, np.ndarray) + """ ie = 'image/encoded' isce = 'image/segmentation/class/encoded' @@ -311,10 +360,12 @@ def process_sceneset_results(self, training_results: List[str], self.temp_dir) with tempfile.TemporaryDirectory() as debug_dir: - make_debug_images(training_record_path_local, debug_dir) + make_debug_images(training_record_path_local, debug_dir, + class_map) shutil.make_archive(training_zip_path_local, 'zip', debug_dir) with tempfile.TemporaryDirectory() as debug_dir: - make_debug_images(validation_record_path_local, debug_dir) + make_debug_images(validation_record_path_local, debug_dir, + class_map) shutil.make_archive(validation_zip_path_local, 'zip', debug_dir) upload_if_needed('{}.zip'.format(training_zip_path_local), diff --git a/src/rastervision/ml_tasks/semantic_segmentation.py b/src/rastervision/ml_tasks/semantic_segmentation.py index 374255c209..6d4b000761 100644 --- a/src/rastervision/ml_tasks/semantic_segmentation.py +++ b/src/rastervision/ml_tasks/semantic_segmentation.py @@ -21,7 +21,7 @@ def get_train_windows(self, scene: Scene, options) -> List[Box]: `make_training_chips` section of the workflow configuration file. - Return: + Returns: A list of windows, list(Box) """ @@ -38,14 +38,28 @@ def get_train_windows(self, scene: Scene, options) -> List[Box]: return windows - def get_train_labels(self, window, scene, options): - label_store = scene.ground_truth_label_store - chip = label_store.src._get_chip(window) - fn = label_store.fn + def get_train_labels(self, window: Box, scene: Scene, + options) -> np.ndarray: + """Get the training labels for the given window in the given scene. + + Args: + window: The window over-which the labels are to be + retrieved. + scene: The scene from-which the window of labels is to be + extracted. + options: Options passed through from the + `make_training_chips` section of the workflow + configuration file. - bit2 = (chip[:, :, 0] > 0) - bit1 = (chip[:, :, 1] > 0) - bit0 = (chip[:, :, 2] > 0) - retval = np.array(bit2 * 4 + bit1 * 2 + bit0 * 1, dtype=np.uint8) + Returns: + An appropriately-shaped 2d np.ndarray containing the labels. - return np.array(fn(retval), dtype=np.uint8) + """ + label_store = scene.ground_truth_label_store + chip = label_store.src._get_chip(window) + r = np.array(chip[:, :, 0], dtype=np.uint32) * (1 << 16) + g = np.array(chip[:, :, 1], dtype=np.uint32) * (1 << 8) + b = np.array(chip[:, :, 2], dtype=np.uint32) * (1 << 0) + packed = r + g + b + retval = np.array(label_store.src_to_dst(packed), dtype=np.uint8) + return retval diff --git a/src/rastervision/protos/label_store.proto b/src/rastervision/protos/label_store.proto index bc5765a52b..a86df3b6a1 100644 --- a/src/rastervision/protos/label_store.proto +++ b/src/rastervision/protos/label_store.proto @@ -42,7 +42,7 @@ message ClassificationGeoJSONFile { message SegmentationRasterFile { optional RasterSource src = 1; optional string dst = 2; - repeated uint32 src_classes = 3 [packed=true]; + repeated string src_classes = 3; repeated int32 dst_classes = 4 [packed=true]; } diff --git a/src/rastervision/samples/workflow-configs/segmentation/deeplab-test-remote.json b/src/rastervision/samples/workflow-configs/segmentation/deeplab-test-remote.json index e45078b327..21976ad74c 100644 --- a/src/rastervision/samples/workflow-configs/segmentation/deeplab-test-remote.json +++ b/src/rastervision/samples/workflow-configs/segmentation/deeplab-test-remote.json @@ -18,7 +18,7 @@ ] } }, - "src_classes": [ 6, 1 ], + "src_classes": [ "#ffff00", "#0000ff" ], "dst_classes": [ 1, 0 ] } } @@ -43,8 +43,8 @@ ] } }, - "src_classes": [ 6 ], - "dst_classes": [ 1 ] + "src_classes": [ "#ffff00", "#0000ff" ], + "dst_classes": [ 1, 0 ] } } } diff --git a/src/rastervision/samples/workflow-configs/segmentation/deeplab-test.json b/src/rastervision/samples/workflow-configs/segmentation/deeplab-test.json index deaa38630e..65202b054d 100644 --- a/src/rastervision/samples/workflow-configs/segmentation/deeplab-test.json +++ b/src/rastervision/samples/workflow-configs/segmentation/deeplab-test.json @@ -18,7 +18,7 @@ ] } }, - "src_classes": [ 6, 1 ], + "src_classes": [ "#ffff00", "#0000ff" ], "dst_classes": [ 1, 0 ] } } @@ -43,8 +43,8 @@ ] } }, - "src_classes": [ 6 ], - "dst_classes": [ 1 ] + "src_classes": [ "#ffff00", "#0000ff" ], + "dst_classes": [ 1, 0 ] } } } @@ -55,7 +55,8 @@ "class_items": [ { "id": 1, - "name": "car" + "name": "car", + "color": "#ffff00" } ] },