Skip to content

Commit

Permalink
Class Map Colors
Browse files Browse the repository at this point in the history
  • Loading branch information
James McClain committed Aug 7, 2018
1 parent 3c9aa2b commit 5617426
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 38 deletions.
39 changes: 33 additions & 6 deletions src/rastervision/label_stores/segmentation_raster_file.py
@@ -1,6 +1,7 @@
import numpy as np

from typing import (List, Union)
from PIL import ImageColor

from rastervision.core.box import Box
from rastervision.core.label_store import LabelStore
Expand All @@ -20,7 +21,7 @@ class SegmentationRasterFile(LabelStore):
def __init__(self,
src: RasterUnion,
dst: RasterUnion,
src_classes: List[int] = [],
src_classes: List[str] = [],
dst_classes: List[int] = []):
"""Constructor.
Expand Down Expand Up @@ -51,16 +52,42 @@ def __init__(self,
else:
raise ValueError('Unsure how to handle dst={}'.format(type(dst)))

# XXX move to utils
def color_to_integer(color: str) -> int:
try:
triple = ImageColor.getrgb(color)
except ValueError:
r = np.random.randint(0, 256)
g = np.random.randint(0, 256)
b = np.random.randint(0, 256)
triple = (r, g, b)

r = triple[0] * (1 << 16)
g = triple[1] * (1 << 8)
b = triple[2] * (1 << 0)
integer = r + g + b
return integer

src_classes = list(map(color_to_integer, src_classes))
correspondence = dict(zip(src_classes, dst_classes))

def fn(n):
def src_to_dst(n: int) -> int:
"""Translate source classes to destination class.
args:
n: A source class represented as a packed rgb pixel
(an integer in the range 0 to 2**24-1).
Returns:
The destination class as an integer.
"""
if n in correspondence:
return correspondence.get(n)
else:
return n
return 0

self.fn = np.vectorize(fn)
self.correspondence = correspondence
self.src_to_dst = np.vectorize(src_to_dst)

def clear(self):
"""Clear all labels."""
Expand Down Expand Up @@ -102,7 +129,7 @@ def get_labels(self, window: Union[Box, None] = None) -> np.ndarray:
returned for the entire scene.
Returns:
numpy.ndarray
np.ndarray
"""
if self.src is not None:
Expand Down
79 changes: 65 additions & 14 deletions src/rastervision/ml_backends/tf_deeplab.py
Expand Up @@ -9,7 +9,7 @@
import uuid

from os.path import join
from PIL import Image
from PIL import Image, ImageColor
from subprocess import Popen
from tensorflow.core.example.example_pb2 import Example
from typing import (List, Tuple)
Expand Down Expand Up @@ -107,7 +107,31 @@ def merge_tf_records(output_path: str, src_records: List[str]) -> None:
print()


def make_debug_images(record_path: str, output_dir: str,
def string_to_triple(color: str) -> np.ndarray:
"""Turn a PIL colorstring into an RGB triple.
Args:
color: A PIL color string
Returns:
An np.ndarray of shape (1,1,3) whether derived from the
string or chosen randomly.
"""
try:
(r, g, b) = ImageColor.getrgb(color)
except AttributeError:
r = np.random.randint(0, 256)
g = np.random.randint(0, 256)
b = np.random.randint(0, 256)
(r, g, b)

return np.array([r, g, b], dtype=np.uint16)


def make_debug_images(record_path: str,
output_dir: str,
class_map: ClassMap,
p: float = 0.25) -> None:
"""Render a random sample of the TFRecords in a given file as
human-viewable PNG files.
Expand All @@ -117,24 +141,48 @@ def make_debug_images(record_path: str, output_dir: str,
output_dir: Destination directory for the generated PNG files.
p: The probability of rendering a particular record.
Return:
Returns:
None
"""
make_dir(output_dir, check_empty=True)

def composite(arr: np.ndarray, *args) -> np.ndarray:
"""Composite the image with the labels.
args:
arr: An np.ndarray of shape (4,) where the first three
entries contains visual data and the fourth contains
a label datum.
*args: Ignored
Returns:
An np.ndarray of shape (1,1,3) where the label datum has
been composited into the visual data using color
information from the color_map variable which has been
captured from the environment.
"""
label = arr[3]
if label == 0:
return arr[0:3]
else:
color = class_map.get_by_id(label).color
label_rgb = string_to_triple(color)
image_rgb = np.array(arr[0:3], dtype=np.uint16)
return np.array((label_rgb + image_rgb) / 2, dtype=np.uint8)

print('Generating debug chips', end='', flush=True)
tfrecord_iter = tf.python_io.tf_record_iterator(record_path)
for ind, example in enumerate(tfrecord_iter):
example = tf.train.Example.FromString(example)
im, labels = parse_tf_example(example)
output_path = join(output_dir, '{}.png'.format(ind))
inv_labels = (labels == 0)
im[:, :, 0] = im[:, :, 0] * inv_labels
im[:, :, 1] = im[:, :, 1] * inv_labels
im[:, :, 2] = im[:, :, 2] * inv_labels
if np.random.rand() <= p:
save_img(im, output_path)
example = tf.train.Example.FromString(example)
im, labels = parse_tf_example(example)
labels3 = labels[:, :, np.newaxis]
im_labels = np.concatenate([im, labels3], axis=2)
output_path = join(output_dir, '{}.png'.format(ind))
composited = np.apply_along_axis(composite, 2, im_labels)
save_img(composited, output_path)
print('.', end='', flush=True)
print()

Expand All @@ -145,8 +193,9 @@ def parse_tf_example(example: Example) -> Tuple[np.ndarray, np.ndarray]:
Args:
example: A TensorFlow Example object.
Return:
Returns:
tuple(np.ndarray, np.ndarray)
"""
ie = 'image/encoded'
isce = 'image/segmentation/class/encoded'
Expand Down Expand Up @@ -311,10 +360,12 @@ def process_sceneset_results(self, training_results: List[str],
self.temp_dir)

with tempfile.TemporaryDirectory() as debug_dir:
make_debug_images(training_record_path_local, debug_dir)
make_debug_images(training_record_path_local, debug_dir,
class_map)
shutil.make_archive(training_zip_path_local, 'zip', debug_dir)
with tempfile.TemporaryDirectory() as debug_dir:
make_debug_images(validation_record_path_local, debug_dir)
make_debug_images(validation_record_path_local, debug_dir,
class_map)
shutil.make_archive(validation_zip_path_local, 'zip',
debug_dir)
upload_if_needed('{}.zip'.format(training_zip_path_local),
Expand Down
34 changes: 24 additions & 10 deletions src/rastervision/ml_tasks/semantic_segmentation.py
Expand Up @@ -21,7 +21,7 @@ def get_train_windows(self, scene: Scene, options) -> List[Box]:
`make_training_chips` section of the workflow
configuration file.
Return:
Returns:
A list of windows, list(Box)
"""
Expand All @@ -38,14 +38,28 @@ def get_train_windows(self, scene: Scene, options) -> List[Box]:

return windows

def get_train_labels(self, window, scene, options):
label_store = scene.ground_truth_label_store
chip = label_store.src._get_chip(window)
fn = label_store.fn
def get_train_labels(self, window: Box, scene: Scene,
options) -> np.ndarray:
"""Get the training labels for the given window in the given scene.
Args:
window: The window over-which the labels are to be
retrieved.
scene: The scene from-which the window of labels is to be
extracted.
options: Options passed through from the
`make_training_chips` section of the workflow
configuration file.
bit2 = (chip[:, :, 0] > 0)
bit1 = (chip[:, :, 1] > 0)
bit0 = (chip[:, :, 2] > 0)
retval = np.array(bit2 * 4 + bit1 * 2 + bit0 * 1, dtype=np.uint8)
Returns:
An appropriately-shaped 2d np.ndarray containing the labels.
return np.array(fn(retval), dtype=np.uint8)
"""
label_store = scene.ground_truth_label_store
chip = label_store.src._get_chip(window)
r = np.array(chip[:, :, 0], dtype=np.uint32) * (1 << 16)
g = np.array(chip[:, :, 1], dtype=np.uint32) * (1 << 8)
b = np.array(chip[:, :, 2], dtype=np.uint32) * (1 << 0)
packed = r + g + b
retval = np.array(label_store.src_to_dst(packed), dtype=np.uint8)
return retval
2 changes: 1 addition & 1 deletion src/rastervision/protos/label_store.proto
Expand Up @@ -42,7 +42,7 @@ message ClassificationGeoJSONFile {
message SegmentationRasterFile {
optional RasterSource src = 1;
optional string dst = 2;
repeated uint32 src_classes = 3 [packed=true];
repeated string src_classes = 3;
repeated int32 dst_classes = 4 [packed=true];
}

Expand Down
Expand Up @@ -18,7 +18,7 @@
]
}
},
"src_classes": [ 6, 1 ],
"src_classes": [ "#ffff00", "#0000ff" ],
"dst_classes": [ 1, 0 ]
}
}
Expand All @@ -43,8 +43,8 @@
]
}
},
"src_classes": [ 6 ],
"dst_classes": [ 1 ]
"src_classes": [ "#ffff00", "#0000ff" ],
"dst_classes": [ 1, 0 ]
}
}
}
Expand Down
Expand Up @@ -18,7 +18,7 @@
]
}
},
"src_classes": [ 6, 1 ],
"src_classes": [ "#ffff00", "#0000ff" ],
"dst_classes": [ 1, 0 ]
}
}
Expand All @@ -43,8 +43,8 @@
]
}
},
"src_classes": [ 6 ],
"dst_classes": [ 1 ]
"src_classes": [ "#ffff00", "#0000ff" ],
"dst_classes": [ 1, 0 ]
}
}
}
Expand All @@ -55,7 +55,8 @@
"class_items": [
{
"id": 1,
"name": "car"
"name": "car",
"color": "#ffff00"
}
]
},
Expand Down

0 comments on commit 5617426

Please sign in to comment.