Skip to content

Commit

Permalink
Greater Configurability
Browse files Browse the repository at this point in the history
  • Loading branch information
James McClain committed Aug 14, 2018
1 parent c934ef7 commit 7e47b35
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/rastervision/label_stores/segmentation_raster_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def src_to_rv(n: int) -> int:
if n in src_to_rv_class_map:
return src_to_rv_class_map.get(n)
else:
return 0
return 0x00

self.src_to_rv = np.vectorize(src_to_rv, otypes=[np.uint8])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class TestingRasterSource(RasterSource):

def __init__(self, zeros=False):
self.width = 4
self.height = 4
Expand Down Expand Up @@ -42,6 +43,7 @@ def get_crs_transformer(self, window):


class TestSegmentationRasterFile(unittest.TestCase):

def test_clear(self):
label_store = SegmentationRasterFile(TestingRasterSource(), None)
extent = label_store.src.get_extent()
Expand All @@ -58,7 +60,7 @@ def test_set_labels(self):
label_store.set_labels(raster_source)
extent = label_store.src.get_extent()
rs_data = raster_source._get_chip(extent)
ls_data = label_store.get_labels(extent)
ls_data = (label_store.get_labels(extent) == 1)
self.assertEqual(rs_data.sum(), ls_data.sum())


Expand Down
24 changes: 17 additions & 7 deletions src/rastervision/ml_backends/tf_deeplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,15 @@ def create_tf_example(image: np.ndarray,
"""
class_keys = set(class_map.get_keys())

def fn(n):
return (n if n in class_keys else 0)

filtered_labels = np.array(np.vectorize(fn)(labels), dtype=np.uint8)
def _clean(n):
return (n if n in class_keys else 0x00)
clean = np.vectorize(_clean, otypes=[np.uint8])

image_encoded = numpy_to_png(image)
image_filename = chip_id.encode('utf8')
image_format = 'png'.encode('utf8')
image_height, image_width, image_channels = image.shape
image_segmentation_class_encoded = numpy_to_png(filtered_labels)
image_segmentation_class_encoded = numpy_to_png(clean(labels))
image_segmentation_class_format = 'png'.encode('utf8')

features = tf.train.Features(
Expand Down Expand Up @@ -308,6 +307,9 @@ def get_training_args(train_py: str, train_logdir_local: str, tfic_index: str,
for size in be_options.train_crop_size:
args.append('--train_crop_size={}'.format(size))

if len(be_options.dataset) > 0:
args.append('--dataset={}'.format(be_options.dataset))

args.append('--train_logdir={}'.format(train_logdir_local))
args.append('--tf_initial_checkpoint={}'.format(tfic_index))
args.append('--dataset_dir={}'.format(dataset_dir_local))
Expand All @@ -317,14 +319,22 @@ def get_training_args(train_py: str, train_logdir_local: str, tfic_index: str,
args.append('--decoder_output_stride={}'.format(
be_options.decoder_output_stride))
args.append('--train_batch_size={}'.format(be_options.train_batch_size))
if len(be_options.dataset) > 0:
args.append('--dataset="{}"'.format(be_options.dataset))
args.append('--save_interval_secs={}'.format(
be_options.save_interval_secs))
args.append('--save_summaries_secs={}'.format(
be_options.save_summaries_secs))
args.append('--save_summaries_images={}'.format(
be_options.save_summaries_images))
args.append('--last_layer_gradient_multiplier={}'.format(
be_options.last_layer_gradient_multiplier))
args.append('--initialize_last_layer={}'.format(
be_options.initialize_last_layer))
args.append('--min_scale_factor={}'.format(be_options.min_scale_factor))
args.append('--max_scale_factor={}'.format(be_options.max_scale_factor))
args.append('--fine_tune_batch_norm={}'.format(
be_options.fine_tune_batch_norm))
args.append('--last_layers_contain_logits_only={}'.format(
be_options.last_layers_contain_logits_only))

return args

Expand Down
14 changes: 10 additions & 4 deletions src/rastervision/protos/deeplab/train.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@ message TrainingParameters {
optional string train_split = 5 [default = "train"];
required string model_variant = 6;
repeated int32 atrous_rates = 7;
optional int32 output_stride = 8 [default = 1];
optional int32 decoder_output_stride = 9 [default = 1];
optional int32 output_stride = 8 [default = 8];
optional int32 decoder_output_stride = 9 [default = 8];
repeated int32 train_crop_size = 10;
optional int32 train_batch_size = 11 [default = 1];
optional int32 train_batch_size = 11 [default = 8];
optional string dataset = 12 [default = ""];
optional int32 save_interval_secs = 13 [default = 600];
optional int32 save_summaries_secs = 14 [default = 30];
optional int32 save_summaries_secs = 14 [default = 5];
optional bool save_summaries_images = 15 [default = true];
optional float last_layer_gradient_multiplier = 16 [default = 1.0];
optional bool initialize_last_layer = 17 [default = true];
optional float min_scale_factor = 18 [default = 1.0];
optional float max_scale_factor = 19 [default = 1.0];
optional bool fine_tune_batch_norm = 20 [default = true];
optional bool last_layers_contain_logits_only = 21 [default = false];
}

0 comments on commit 7e47b35

Please sign in to comment.