Skip to content

Commit

Permalink
Merge pull request #653 from jamesmcclain/eval
Browse files Browse the repository at this point in the history
Eval Script
  • Loading branch information
jamesmcclain committed Jan 24, 2019
2 parents 7097791 + 8ea561f commit 66a1ef5
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Expand Up @@ -13,6 +13,7 @@ Raster Vision 0.9.0
- Remove custom ``__deepcopy__`` implementation from ``ConfigBuilder``s. `#567 <https://github.com/azavea/raster-vision/pull/567>`_
- Add ability to shift raster images by given numbers of meters. `#573 <https://github.com/azavea/raster-vision/pull/573>`_
- Add ability to generate GeoJSON segmentation predictions. `#575 <https://github.com/azavea/raster-vision/pull/575>`_
- Add ability to run the DeepLab eval script. `#653 <https://github.com/azavea/raster-vision/pull/653>`_
Raster Vision 0.8
-----------------
Expand Down
63 changes: 62 additions & 1 deletion rastervision/backend/tf_deeplab.py
Expand Up @@ -176,7 +176,7 @@ def create_tf_example(image: np.ndarray,
labels: np.ndarray,
class_map: ClassMap,
chip_id: str = ''):
"""Create a TensorFlow from an image, the labels, &c.
"""Create a TensorFlow Example from an image, the labels, &c.
Args:
image: An np.ndarray containing the image data.
Expand Down Expand Up @@ -265,6 +265,56 @@ def get_latest_checkpoint(train_logdir_local: str) -> str:
return latest[:len(latest) - len('.meta')]


def get_evaluation_args(eval_py: str, train_logdir_local: str,
dataset_dir_local: str, eval_logdir: str, tfdl_config):
"""Generate the array of arguments needed to run the eval script.
Args:
eval_py: The URI of the eval script.
train_logdir_local: The directory in-which checkpoints can be
found.
dataset_dir_local: The directory in which the records are
found.
eval_logdir: The directory where evaluation events should be
logged.
tfdl_config: google.protobuf.Struct with fields from
rv.protos.deeplab.train.proto containing TF Deeplab training configuration
Returns:
A list of arguments suitable for running the eval script.
"""
fields = [
'dataset',
'output_stride',
'decoder_output_stride',
'model_variant',
'eval_split',
]

multi_fields = [
'atrous_rates',
'eval_crop_size',
]

args = ['python', eval_py]

args.append('--checkpoint_dir={}'.format(train_logdir_local))
args.append('--eval_logdir={}'.format(eval_logdir))
args.append('--dataset_dir={}'.format(dataset_dir_local))

for field in multi_fields:
for item in tfdl_config.__getattribute__(field):
args.append('--{}={}'.format(field, item))

for field in fields:
field_value = tfdl_config.__getattribute__(field)
if (not type(field_value) is str) or (not len(field_value) == 0):
args.append('--{}={}'.format(field, field_value))

return args


def get_training_args(train_py: str, train_logdir_local: str, tfic_ckpt: str,
dataset_dir_local: str, num_classes: int,
tfdl_config) -> Tuple[List[str], Dict[str, str]]:
Expand Down Expand Up @@ -521,6 +571,7 @@ def train(self, tmp_dir: str) -> None:
None
"""
train_py = self.backend_config.script_locations.train_py
eval_py = self.backend_config.script_locations.eval_py
export_py = self.backend_config.script_locations.export_py

# Setup local input and output directories
Expand Down Expand Up @@ -599,6 +650,16 @@ def train(self, tmp_dir: str) -> None:
['tensorboard', '--logdir={}'.format(train_logdir_local)])
terminate_at_exit(tensorboard_process)

if self.backend_config.train_options.do_eval:
# Start eval script
log.info('Starting eval script')
eval_logdir = train_logdir_local
eval_args = get_evaluation_args(eval_py, train_logdir_local,
dataset_dir_local, eval_logdir,
tfdl_config)
eval_process = Popen(eval_args, env=train_env)
terminate_at_exit(eval_process)

# Wait for training and tensorboard
log.info('Waiting for training and tensorboard processes')
train_process.wait()
Expand Down
37 changes: 26 additions & 11 deletions rastervision/backend/tf_deeplab_config.py
Expand Up @@ -12,6 +12,7 @@

# Default location to Tensorflow Deeplab's scripts.
DEFAULT_SCRIPT_TRAIN = '/opt/tf-models/deeplab/train.py'
DEFAULT_SCRIPT_EVAL = '/opt/tf-models/deeplab/eval.py'
DEFAULT_SCRIPT_EXPORT = '/opt/tf-models/deeplab/export_model.py'
CHIP_OUTPUT_FILES = ['train-{}.record', 'validation-{}.record']
DEBUG_CHIP_OUTPUT_FILES = ['train.zip', 'validation.zip']
Expand All @@ -23,17 +24,21 @@ def __init__(self,
train_restart_dir=None,
sync_interval=600,
do_monitoring=True,
replace_model=False):
replace_model=False,
do_eval=False):
self.train_restart_dir = train_restart_dir
self.sync_interval = sync_interval
self.do_monitoring = do_monitoring
self.replace_model = replace_model
self.do_eval = do_eval

class ScriptLocations:
def __init__(self,
train_py=DEFAULT_SCRIPT_TRAIN,
export_py=DEFAULT_SCRIPT_EXPORT):
export_py=DEFAULT_SCRIPT_EXPORT,
eval_py=DEFAULT_SCRIPT_EVAL):
self.train_py = train_py
self.eval_py = eval_py
self.export_py = export_py

def __init__(self,
Expand Down Expand Up @@ -75,10 +80,12 @@ def create_backend(self, task_config):
def to_proto(self):
d = {
'train_py': self.script_locations.train_py,
'eval_py': self.script_locations.eval_py,
'export_py': self.script_locations.export_py,
'train_restart_dir': self.train_options.train_restart_dir,
'sync_interval': self.train_options.sync_interval,
'do_monitoring': self.train_options.do_monitoring,
'do_eval': self.train_options.do_eval,
'replace_model': self.train_options.replace_model,
'debug': self.debug,
'training_data_uri': self.training_data_uri,
Expand Down Expand Up @@ -192,9 +199,12 @@ def from_proto(self, msg):
train_restart_dir=conf.train_restart_dir,
sync_interval=conf.sync_interval,
do_monitoring=conf.do_monitoring,
replace_model=conf.replace_model)
replace_model=conf.replace_model,
do_eval=conf.do_eval)
b = b.with_script_locations(
train_py=conf.train_py, export_py=conf.export_py)
train_py=conf.train_py,
export_py=conf.export_py,
eval_py=conf.eval_py)
b = b.with_training_data_uri(conf.training_data_uri)
b = b.with_training_output_uri(conf.training_output_uri)
b = b.with_model_uri(conf.model_uri)
Expand Down Expand Up @@ -248,7 +258,8 @@ def _applicable_tasks(self):
def _process_task(self):
return self.with_config(
{
'trainCropSize': [self.task.chip_size, self.task.chip_size]
'trainCropSize': [self.task.chip_size, self.task.chip_size],
'evalCropSize': [self.task.chip_size, self.task.chip_size]
},
ignore_missing_keys=True)

Expand Down Expand Up @@ -327,25 +338,28 @@ def with_train_options(self,
train_restart_dir=None,
sync_interval=600,
do_monitoring=True,
replace_model=False):
replace_model=False,
do_eval=False):
"""Sets the train options for this backend.
Args:
sync_interval: How often to sync output of training to the cloud
(in seconds).
do_monitoring: Run process to monitor training (eg. Tensorboard)
replace_model: Replace the model checkpoint if exists.
If false, this will continue training from
checkpoing if exists, if the backend allows for this.
do_eval: Boolean determining whether to run the eval
script.
"""
b = deepcopy(self)
b.config['train_options'] = TFDeeplabConfig.TrainOptions(
train_restart_dir=train_restart_dir,
sync_interval=sync_interval,
do_monitoring=do_monitoring,
replace_model=replace_model)
replace_model=replace_model,
do_eval=do_eval)

return b.with_config(
{
Expand Down Expand Up @@ -397,9 +411,10 @@ def with_training_output_uri(self, training_output_uri):

def with_script_locations(self,
train_py=DEFAULT_SCRIPT_TRAIN,
export_py=DEFAULT_SCRIPT_EXPORT):
export_py=DEFAULT_SCRIPT_EXPORT,
eval_py=DEFAULT_SCRIPT_EVAL):
script_locs = TFDeeplabConfig.ScriptLocations(
train_py=train_py, export_py=export_py)
train_py=train_py, export_py=export_py, eval_py=eval_py)
b = deepcopy(self)
b.config['script_locations'] = script_locs
return b
2 changes: 2 additions & 0 deletions rastervision/protos/backend.proto
Expand Up @@ -42,11 +42,13 @@ message BackendConfig {

message TFDeeplabConfig {
optional string train_py = 1 [default="/opt/tf-models/deeplab/train.py"];
optional string eval_py = 14 [default="/opt/tf-models/deeplab/eval.py"];
optional string export_py = 2 [default="/opt/tf-models/deeplab/export_model.py"];

optional string train_restart_dir = 3;
optional int32 sync_interval = 4 [default=600];
optional bool do_monitoring = 5 [default=true];
optional bool do_eval = 13 [default=false];
optional bool replace_model = 6 [default=false];

optional string training_data_uri = 7;
Expand Down

0 comments on commit 66a1ef5

Please sign in to comment.