Skip to content

Commit

Permalink
Export and TensorBoard
Browse files Browse the repository at this point in the history
  • Loading branch information
James McClain committed Aug 2, 2018
1 parent f86c3af commit 101c91d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 33 deletions.
2 changes: 1 addition & 1 deletion scripts/run
Expand Up @@ -35,7 +35,7 @@ do
key="$1"
case $key in
--aws)
AWS="-e AWS_PROFILE=${AWS_PROFILE} -v ${HOME}/.aws:/root/.aws:ro"
AWS="-e AWS_PROFILE=${AWS_PROFILE:-default} -v ${HOME}/.aws:/root/.aws:ro"
shift # past argument
;;
--tensorboard)
Expand Down
66 changes: 46 additions & 20 deletions src/rastervision/ml_backends/tf_deeplab.py
Expand Up @@ -25,8 +25,8 @@
write_tf_record, terminate_at_exit, TRAIN, VALIDATION)
from rastervision.utils.misc import save_img
from rastervision.utils.files import (get_local_path, upload_if_needed,
make_dir, download_if_needed,
sync_dir, start_sync)
make_dir, download_if_needed, sync_dir,
start_sync)


def numpy_to_png(array: np.ndarray) -> str:
Expand Down Expand Up @@ -220,6 +220,13 @@ def get_record_uri(uri: str, split: str) -> str:
return join(uri, '{}-0.record'.format(split))


def get_latest_checkpoint(train_logdir_local: str) -> str:
ckpts = glob.glob(join(train_logdir_local, 'model.ckpt-*.meta'))
times = map(os.path.getmtime, ckpts)
latest = sorted(zip(times, ckpts))[-1][1]
return latest[:len(latest) - len('.meta')]


class TFDeeplab(MLBackend):
"""MLBackend-derived type that implements the TensorFlow DeepLab
backend.
Expand Down Expand Up @@ -330,6 +337,7 @@ def train(self, class_map: ClassMap, options) -> None:
soptions = options.segmentation_options

train_py = soptions.train_py
export_model_py = soptions.export_model_py

# Setup local input and output directories
train_logdir = options.output_uri
Expand All @@ -338,17 +346,7 @@ def train(self, class_map: ClassMap, options) -> None:
dataset_dir_local = get_local_path(dataset_dir, self.temp_dir)
make_dir(train_logdir_local)
make_dir(dataset_dir_local)

download_if_needed(get_record_uri(dataset_dir, TRAIN), self.temp_dir)
# XXX
# In spite of the prohibition, it might make sense to log
# directly to s3 in the remote case. The commented-out code
# below does not work because the absolute path seems to be
# hard-coded into the state, and that (potentially) changes
# run-to-run due to the use of a temporary directory with a
# random name.
# if urlparse(train_logdir).scheme == 's3':
# sync_dir(train_logdir, train_logdir_local, delete=True)

# Download and untar initial checkpoint.
tf_initial_checkpoints_uri = soptions.tf_initial_checkpoints_uri
Expand All @@ -364,28 +362,41 @@ def train(self, class_map: ClassMap, options) -> None:
# Build array of argments that will be used to run the DeepLab
# training script.
args = ['python', train_py]

args.append('--train_logdir={}'.format(train_logdir_local))
args.append('--tf_initial_checkpoint={}'.format(tfic_index))
args.append('--dataset_dir={}'.format(dataset_dir_local))
args.append('--training_number_of_steps={}'.format(
soptions.training_number_of_steps))

steps = soptions.training_number_of_steps
if steps > 0:
args.append('--training_number_of_steps={}'.format(steps))

if len(soptions.train_split) > 0:
args.append('--train_split="{}"'.format(soptions.train_split))
args.append('--train_split={}'.format(soptions.train_split))

if len(soptions.model_variant) > 0:
args.append('--model_variant="{}"'.format(soptions.model_variant))
args.append('--model_variant={}'.format(soptions.model_variant))

for rate in soptions.atrous_rates:
args.append('--atrous_rates={}'.format(rate))

args.append('--output_stride={}'.format(soptions.output_stride))
args.append('--decoder_output_stride={}'.format(
soptions.decoder_output_stride))

for size in soptions.train_crop_size:
args.append('--train_crop_size={}'.format(size))

args.append('--train_batch_size={}'.format(soptions.train_batch_size))

if len(soptions.dataset):
args.append('--dataset="{}"'.format(soptions.dataset))

# XXX
# See the block comment above regarding train_logdir.
args.append('--save_interval_secs={}'.format(soptions.save_interval_secs))
args.append('--save_summaries_secs={}'.format(soptions.save_summaries_secs))
args.append('--save_summaries_images={}'.format(soptions.save_summaries_images))

# Periodically synchronize with remote
start_sync(
train_logdir_local,
train_logdir,
Expand All @@ -394,13 +405,28 @@ def train(self, class_map: ClassMap, options) -> None:
# Train
train_process = Popen(args)
terminate_at_exit(train_process)
tensorboard_process = Popen(
['tensorboard', '--logdir={}'.format(train_logdir_local)])
terminate_at_exit(tensorboard_process)
train_process.wait()
tensorboard_process.terminate()

# Build array of arguments that will be used to run the DeepLab
# export script.
args = ['python', export_model_py]
args.append('--checkpoint_path={}'.format(
get_latest_checkpoint(train_logdir_local)))
args.append('--export_path={}'.format(
join(train_logdir_local, 'frozen_inference_graph.pb')))

# Export
export_process = Popen(args)
terminate_at_exit(export_process)
export_process.wait()

if urlparse(train_logdir).scheme == 's3':
sync_dir(train_logdir_local, train_logdir, delete=True)

# XXX tensorboard

def predict(self, chip, options):
import pdb
pdb.set_trace()
Expand Down
25 changes: 15 additions & 10 deletions src/rastervision/protos/train.proto
Expand Up @@ -19,16 +19,21 @@ message TrainConfig {

message SegmentationOptions {
optional string train_py = 1 [default="/opt/tf-models/deeplab/train.py"];
required string tf_initial_checkpoints_uri = 2;
optional int32 training_number_of_steps = 3 [default=1];
optional string train_split = 4;
optional string model_variant = 5;
repeated int32 atrous_rates = 6;
optional int32 output_stride = 7 [default=16];
optional int32 decoder_output_stride = 8 [default=4];
repeated int32 train_crop_size = 9;
optional int32 train_batch_size = 10 [default=1];
optional string dataset = 11;
optional string export_model_py = 2 [default="/opt/tf-models/deeplab/export_model.py"];
required string tf_initial_checkpoints_uri = 3;
optional int32 training_number_of_steps = 4 [default=0];
optional string train_split = 5;
optional string model_variant = 6;
repeated int32 atrous_rates = 7;
optional int32 output_stride = 8 [default=16];
optional int32 decoder_output_stride = 9 [default=4];
repeated int32 train_crop_size = 10;
optional int32 train_batch_size = 11 [default=1];
optional string dataset = 12;
optional int32 log_steps = 13 [default=10];
optional int32 save_interval_secs = 14 [default=1200];
optional int32 save_summaries_secs = 15 [default=29];
optional bool save_summaries_images = 16 [default=true];
}

message Options {
Expand Down
Expand Up @@ -68,7 +68,13 @@
"backend_config_uri": "/dev/null",
"sync_interval": 600,
"segmentation_options": {
"tf_initial_checkpoints_uri": "{rv_root}/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz"
"tf_initial_checkpoints_uri": "{rv_root}/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz",
"training_number_of_steps": 1
"train_split": "train",
"model_variant": "mobilenet_v2",
"atrous_rates": [6, 12, 18],
"decoder_output_stride": 4,
"train_batch_size": 12
}
},
"predict_options": {
Expand Down
Expand Up @@ -68,7 +68,13 @@
"backend_config_uri": "/dev/null",
"sync_interval": 600,
"segmentation_options": {
"tf_initial_checkpoints_uri": "/opt/data/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz"
"tf_initial_checkpoints_uri": "/opt/data/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz",
"training_number_of_steps": 0,
"train_split": "train",
"model_variant": "mobilenet_v2",
"atrous_rates": [6, 12, 18],
"decoder_output_stride": 4,
"train_batch_size": 2
}
},
"predict_options": {
Expand Down

0 comments on commit 101c91d

Please sign in to comment.