Merge pull request #461 from azavea/rde/feature/archive-checkpoints
Fine tune checkpoint output
lossyrob committed Oct 7, 2018
2 parents 12550e2 + 8ffe64a commit 499effd
Showing 9 changed files with 143 additions and 207 deletions.
102 changes: 0 additions & 102 deletions src/integration_tests/chip_classification_tests/configs/workflow.json

This file was deleted.

91 changes: 0 additions & 91 deletions src/integration_tests/object_detection_tests/configs/workflow.json

This file was deleted.

21 changes: 21 additions & 0 deletions src/rastervision/backend/tf_deeplab.py
@@ -26,6 +26,7 @@
sync_to_dir, sync_from_dir)
from rastervision.utils.misc import (numpy_to_png, png_to_numpy, save_img)
from rastervision.data.label_source.utils import color_to_integer
from rastervision.rv_config import RVConfig

FROZEN_INFERENCE_GRAPH = 'model'
INPUT_TENSOR_NAME = 'ImageTensor:0'
@@ -610,6 +611,26 @@ def train(self, tmp_dir: str) -> None:
terminate_at_exit(export_process)
export_process.wait()

# Package up the model files for use as fine-tuning checkpoints
fine_tune_checkpoint_name = self.backend_config.fine_tune_checkpoint_name
latest_checkpoints = get_latest_checkpoint(train_logdir_local)
model_checkpoint_files = glob.glob(
'{}*'.format(latest_checkpoints))
inference_graph_path = os.path.join(train_logdir_local, 'model')

with RVConfig.get_tmp_dir() as tmp_dir:
model_dir = os.path.join(tmp_dir, fine_tune_checkpoint_name)
make_dir(model_dir)
model_tar = os.path.join(
train_logdir_local,
'{}.tar.gz'.format(fine_tune_checkpoint_name))
shutil.copy(inference_graph_path,
'{}/frozen_inference_graph.pb'.format(model_dir))
for path in model_checkpoint_files:
shutil.copy(path, model_dir)
with tarfile.open(model_tar, 'w:gz') as tar:
tar.add(model_dir, arcname=os.path.basename(model_dir))

# Perform final sync
sync_to_dir(train_logdir_local, train_logdir, delete=False)
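
For reference, a minimal sketch of what this packaging step yields: after it runs, the local train directory holds a {fine_tune_checkpoint_name}.tar.gz whose top-level folder contains the frozen graph plus the latest model.ckpt* files, and the final sync carries it to the remote train directory. The archive can be inspected with the standard library; the name 'spacenet-deeplab' below is a hypothetical experiment ID, not something from this diff.

import tarfile

# Expected layout, given the copies above ('<step>' stands in for
# the latest checkpoint number):
#   spacenet-deeplab/frozen_inference_graph.pb
#   spacenet-deeplab/model.ckpt-<step>.index
#   spacenet-deeplab/model.ckpt-<step>.meta
#   spacenet-deeplab/model.ckpt-<step>.data-00000-of-00001
with tarfile.open('spacenet-deeplab.tar.gz', 'r:gz') as tar:
    for member in tar.getmembers():
        print(member.name, member.size)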

23 changes: 19 additions & 4 deletions src/rastervision/backend/tf_deeplab_config.py
@@ -43,7 +43,8 @@ def __init__(self,
debug=False,
training_data_uri=None,
training_output_uri=None,
model_uri=None):
model_uri=None,
fine_tune_checkpoint_name=None):
if train_options is None:
train_options = TFDeeplabConfig.TrainOptions()
if script_locations is None:
@@ -60,6 +61,7 @@ def __init__(self,
self.training_data_uri = training_data_uri
self.training_output_uri = training_output_uri
self.model_uri = model_uri
self.fine_tune_checkpoint_name = fine_tune_checkpoint_name

def create_backend(self, task_config):
return TFDeeplab(self, task_config)
@@ -76,6 +78,7 @@ def to_proto(self):
'training_data_uri': self.training_data_uri,
'training_output_uri': self.training_output_uri,
'model_uri': self.model_uri,
'fine_tune_checkpoint_name': self.fine_tune_checkpoint_name,
'tfdl_config': self.tfdl_config
}

@@ -110,6 +113,11 @@ def preprocess_command(self, command_type, experiment_config,

conf.model_uri = os.path.join(conf.training_output_uri, 'model')
io_def.add_output(conf.model_uri)

# Default the fine-tune checkpoint name to the experiment ID
if not conf.fine_tune_checkpoint_name:
conf.fine_tune_checkpoint_name = experiment_config.id
io_def.add_output(conf.fine_tune_checkpoint_name)
if command_type in [rv.PREDICT, rv.BUNDLE]:
if not conf.model_uri:
io_def.add_missing('Missing model_uri.')
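
In effect, when no explicit name is configured, the archive written during training is named after the experiment. A small standalone sketch of that defaulting behavior (the helper function and the values are illustrative, not part of this diff):

import os

def checkpoint_tar_path(training_output_uri, experiment_id,
                        fine_tune_checkpoint_name=None):
    # Mirrors the fallback above: use the experiment ID when no
    # explicit fine-tune checkpoint name is set.
    name = fine_tune_checkpoint_name or experiment_id
    return os.path.join(training_output_uri, '{}.tar.gz'.format(name))

print(checkpoint_tar_path('/opt/data/train', 'spacenet-deeplab'))
# -> /opt/data/train/spacenet-deeplab.tar.gz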
@@ -148,7 +156,8 @@ def __init__(self, prev=None):
'debug': prev.debug,
'training_data_uri': prev.training_data_uri,
'training_output_uri': prev.training_output_uri,
'model_uri': prev.model_uri
'model_uri': prev.model_uri,
'fine_tune_checkpoint_name': prev.fine_tune_checkpoint_name
}
super().__init__(rv.TF_DEEPLAB, TFDeeplabConfig, config, prev)
self.config_mods = []
@@ -161,8 +170,6 @@ def from_proto(self, msg):
# assume the task has already been set and do not
# require it during validation.
b.require_task = False
if self.config.get('pretrained_model_uri'):
b = b.with_pretrained_model_uri(self.config.pretrained_model_uri)
b = b.with_train_options(
train_restart_dir=conf.train_restart_dir,
sync_interval=conf.sync_interval,
@@ -173,6 +180,7 @@
b = b.with_training_data_uri(conf.training_data_uri)
b = b.with_training_output_uri(conf.training_output_uri)
b = b.with_model_uri(conf.model_uri)
b = b.with_fine_tune_checkpoint_name(conf.fine_tune_checkpoint_name)
b = b.with_debug(conf.debug)
b = b.with_template(json_format.MessageToDict(conf.tfdl_config))
return b
@@ -315,6 +323,13 @@ def with_model_uri(self, model_uri):
b.config['model_uri'] = model_uri
return b

def with_fine_tune_checkpoint_name(self, fine_tune_checkpoint_name):
"""Defines the name of the fine tune checkpoint that will
be created for this model after training."""
b = deepcopy(self)
b.config['fine_tune_checkpoint_name'] = fine_tune_checkpoint_name
return b

def with_training_data_uri(self, training_data_uri):
b = deepcopy(self)
b.config['training_data_uri'] = training_data_uri
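For illustration, here is how the new option might slot into the builder chain used by this backend. A sketch assuming the builder API shown in this file; the task variable and the rv.MOBILENET_V2 model-defaults key are assumptions, not taken from this diff.

import rastervision as rv

# Hypothetical configuration; 'task' would be a previously built
# TaskConfig, and the model-defaults key is a placeholder.
backend = rv.BackendConfig.builder(rv.TF_DEEPLAB) \
                          .with_task(task) \
                          .with_model_defaults(rv.MOBILENET_V2) \
                          .with_fine_tune_checkpoint_name('spacenet-deeplab') \
                          .build()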
27 changes: 25 additions & 2 deletions src/rastervision/backend/tf_object_detection.py
@@ -285,6 +285,7 @@ def get_last_checkpoint_path(train_root_dir):
def export_inference_graph(train_root_dir,
config_path,
output_dir,
fine_tune_checkpoint_name,
export_py=None):
export_py = (export_py or
'/opt/tf-models/object_detection/export_inference_graph.py')
@@ -302,8 +303,26 @@
])
train_process.wait()

inference_graph_path = join(output_dir, 'frozen_inference_graph.pb')

# Package up the model files for use as fine-tuning checkpoints
model_checkpoint_files = [
os.path.join(output_dir, fname) for fname in os.listdir(output_dir)
if fname.startswith('model.ckpt')
]
with RVConfig.get_tmp_dir() as tmp_dir:
model_dir = os.path.join(tmp_dir, fine_tune_checkpoint_name)
make_dir(model_dir)
model_tar = os.path.join(
output_dir, '{}.tar.gz'.format(fine_tune_checkpoint_name))
shutil.copy(inference_graph_path, model_dir)
for path in model_checkpoint_files:
shutil.copy(path, model_dir)
with tarfile.open(model_tar, 'w:gz') as tar:
tar.add(model_dir, arcname=os.path.basename(model_dir))

# Move frozen inference graph and clean up generated files.
output_path = join(output_dir, 'model')
shutil.move(inference_graph_path, output_path)
saved_model_dir = join(output_dir, 'saved_model')
@@ -693,7 +712,11 @@ def train(self, tmp_dir):
do_monitoring=self.config.train_options.do_monitoring)

export_inference_graph(
output_dir, local_config_path, output_dir, export_py=export_py)
output_dir,
local_config_path,
output_dir,
fine_tune_checkpoint_name=self.config.fine_tune_checkpoint_name,
export_py=export_py)

# Perform final sync
sync_to_dir(output_dir, self.config.training_output_uri)
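
Presumably the point of these archives is that a later experiment can start training from them, via the with_pretrained_model_uri hook that appears elsewhere in this diff. A sketch of such a follow-on run; whether the object-detection builder accepts this archive directly is an assumption, and the S3 path, task variable, and model-defaults key are all placeholders.

import rastervision as rv

# Hypothetical fine-tuning run that starts from an archive produced
# by the packaging step above; all values here are placeholders.
backend = rv.BackendConfig.builder(rv.TF_OBJECT_DETECTION) \
                          .with_task(task) \
                          .with_model_defaults(rv.SSD_MOBILENET_V1_COCO) \
                          .with_pretrained_model_uri(
                              's3://my-bucket/train/spacenet-od.tar.gz') \
                          .build()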
