Skip to content

Commit

Permalink
More Training Parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
James McClain committed Jul 26, 2018
1 parent f4c765a commit 7455763
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
31 changes: 24 additions & 7 deletions src/rastervision/ml_backends/tf_deeplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def merge_tf_records(output_path, src_records):
print()


def make_debug_images(record_path, output_dir):
def make_debug_images(record_path, output_dir): # XXX
make_dir(output_dir, check_empty=True)

print('Generating debug chips', end='', flush=True)
Expand All @@ -69,9 +69,9 @@ def make_debug_images(record_path, output_dir):
im, labels = parse_tfexample(example)
output_path = join(output_dir, '{}.png'.format(ind))
inv_labels = (labels == 0)
im[:, :, 0] = im[:, :, 0] * inv_labels # XXX
im[:, :, 1] = im[:, :, 1] * inv_labels # XXX
im[:, :, 2] = im[:, :, 2] * inv_labels # XXX
im[:, :, 0] = im[:, :, 0] * inv_labels
im[:, :, 1] = im[:, :, 1] * inv_labels
im[:, :, 2] = im[:, :, 2] * inv_labels
save_img(im, output_path)
print('.', end='', flush=True)
print()
Expand Down Expand Up @@ -170,17 +170,34 @@ def process_sceneset_results(self, training_results, validation_results,
shutil.make_archive(validation_zip_path, 'zip', debug_dir)

def train(self, class_map, options):
soptions = options.segmentation_options

train_logdir = options.output_uri
dataset_dir = options.training_data_uri
train_py = options.segmentation_options.train_py
tf_initial_checkpoints = \
options.segmentation_options.tf_initial_checkpoint
train_py = soptions.train_py
tf_initial_checkpoints = soptions.tf_initial_checkpoint

args = ['python', train_py]
args.append('--train_logdir={}'.format(train_logdir))
args.append(
'--tf_initial_checkpoint={}'.format(tf_initial_checkpoints))
args.append('--dataset_dir={}'.format(dataset_dir))
args.append('--training_number_of_steps={}'.format(
soptions.training_number_of_steps))
if len(soptions.train_split) > 0:
args.append('--train_split="{}"'.format(soptions.train_split))
if len(soptions.model_variant) > 0:
args.append('--model_variant="{}"'.format(soptions.model_variant))
for rate in soptions.atrous_rates:
args.append('--atrous_rates={}'.format(rate))
args.append('--output_stride={}'.format(soptions.output_stride))
args.append('--decoder_output_stride={}'.format(
soptions.decoder_output_stride))
for size in soptions.train_crop_size:
args.append('--train_crop_size={}'.format(size))
args.append('--train_batch_size={}'.format(soptions.train_batch_size))
if len(soptions.dataset):
args.append('--dataset="{}"'.format(soptions.dataset))

train_process = Popen(args)
terminate_at_exit(train_process)
Expand Down
2 changes: 1 addition & 1 deletion src/rastervision/protos/predict.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ message PredictConfig {
oneof ml_options_type {
ObjectDetectionOptions object_detection_options = 5;
ClassificationOptions classification_options = 6;
SegmentationOptions segmentation_options = 8;
SegmentationOptions segmentation_options = 9;
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/rastervision/protos/train.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ message TrainConfig {
message SegmentationOptions {
optional string train_py = 1 [default="/opt/tf-models/deeplab/train.py"];
required string tf_initial_checkpoint = 2;
optional int32 training_number_of_steps = 3 [default=1];
optional string train_split = 4;
optional string model_variant = 5;
repeated int32 atrous_rates = 6;
optional int32 output_stride = 7 [default=16];
optional int32 decoder_output_stride = 8 [default=4];
repeated int32 train_crop_size = 9;
optional int32 train_batch_size = 10 [default=1];
optional string dataset = 11;
}

message Options {
Expand Down

0 comments on commit 7455763

Please sign in to comment.