Commit
add checkpoint support for tensorflow; fix some bugs in the scratch directory handling
jfischer committed Jan 5, 2020
1 parent 3e98a09 commit e45f408
Showing 10 changed files with 308 additions and 17 deletions.
8 changes: 6 additions & 2 deletions dataworkspaces/backends/git.py
@@ -524,8 +524,6 @@ def init_workspace(workspace_name:str, dws_version:str, # type: ignore
f.write("%s\n" % basename(LOCAL_PARAMS_PATH))
f.write("%s\n" % basename(RESOURCE_LOCAL_PARAMS_PATH))
f.write("current_lineage/\n")
- if scratch_dir_gitignore is not None:
-     f.write(scratch_dir_gitignore+"\n")
if exists(join(workspace_dir, '.git')):
click.echo("%s is already a git repository, will just add to it"%
workspace_dir)
@@ -534,7 +532,13 @@ def init_workspace(workspace_name:str, dws_version:str, # type: ignore
git_add(workspace_dir,
[CONFIG_FILE_PATH, RESOURCES_FILE_PATH, GIT_IGNORE_FILE_PATH],
verbose=verbose)
+ if scratch_dir_gitignore is not None:
+     # add the scratch directory's gitignore entry to the top level of
+     # the repo, not the .gitignore within .dataworkspace
+     ensure_entry_in_gitignore(workspace_dir, '.gitignore', scratch_dir_gitignore,
+                               commit=False, verbose=verbose)
commit_changes_in_repo(workspace_dir, "dws init", verbose=verbose)

if not isdir(abs_scratch_dir):
if verbose:
print("Creating scratch directory %s" % abs_scratch_dir)
179 changes: 174 additions & 5 deletions dataworkspaces/kits/tensorflow.py
@@ -123,8 +123,12 @@ def call(self, inputs):
**API**
"""
- from typing import Optional, Union, List
+ from typing import Optional, Union, List, Dict, cast, NamedTuple
assert List
import os
+ from os.path import join, isdir, exists, basename
+ import re
+ import glob

import tensorflow
if tensorflow.__version__.startswith('2.'): # type: ignore
@@ -133,14 +137,17 @@ def call(self, inputs):
USING_TENSORFLOW2=False
import tensorflow.keras.optimizers as optimizers
import tensorflow.keras.utils as kerasutils
+ from tensorflow.keras.callbacks import ModelCheckpoint
if USING_TENSORFLOW2:
import tensorflow.keras.losses as losses
else:
import tensorflow.losses as losses

- from dataworkspaces.workspace import find_and_load_workspace, ResourceRef
+ from dataworkspaces.workspace import find_and_load_workspace, ResourceRef,\
+     ResourceRoles, FileResourceMixin
from dataworkspaces.errors import ConfigurationError
from dataworkspaces.kits.wrapper_utils import _DwsModelState, _add_to_hash,\
-     NotSupportedError
+     NotSupportedError, _find_resource


def _verify_eager_if_dataset(x, y, api_resource):
@@ -201,12 +208,150 @@ def __iter__(self):
def on_epoch_end(self):
return self.seq.on_epoch_end()  # delegate to the wrapped sequence

class DwsModelCheckpoint(ModelCheckpoint):
"""
Subclass of tf.keras.callbacks.ModelCheckpoint which will save checkpoints
to the workspace's scratch space and then move the most recent/best checkpoint
to the results directory at the end of the run.
You can instantiate this class directly and pass it to the ``callbacks``
parameter of the model's ``fit()`` method::
model.fit(train_images, train_labels, epochs=10,
callbacks=[DwsModelCheckpoint('fashion', monitor='loss', save_best_only=True)])
You can also pass a :class:`~CheckpointConfig` instance to the
:func:`~add_lineage_to_keras_model_class` wrapper function.
"""
def __init__(self, model_name:str, monitor:str='val_loss',
save_best_only:bool=False, mode:str='auto',
save_freq:Union[str, int]='epoch',
results_resource:Optional[Union[str, ResourceRef]]=None,
workspace_dir:Optional[str]=None,
verbose:Union[int,bool]=0):
"""
model_name is used to create the checkpoint filenames. The checkpoints
will be saved as MODEL_NAME_{epoch} (e.g. fashion_1, fashion_2, ...).
Currently, only the save_weights_only option is supported.
verbose can be either 0/1 in the style of TensorFlow or True/False
in the style of Data Workspaces.
"""
self.dws_model_name = model_name
if verbose==0 or verbose==False:
tf_verbose = 0
dws_verbose = False
else:
tf_verbose = 1
dws_verbose = True

self.workspace = find_and_load_workspace(batch=True, verbose=dws_verbose,
uri_or_local_path=workspace_dir)

results_ref= _find_resource(self.workspace, ResourceRoles.RESULTS,
results_resource)
self.results_resource = self.workspace.get_resource(results_ref.name)
if not isinstance(self.results_resource, FileResourceMixin):
raise ConfigurationError("Resource %s is not a file-based resource"%
results_ref.name)
self.results_subdir = results_ref.subpath # type: Optional[str]
scratch_dir = self.workspace.get_scratch_directory()
assert isdir(scratch_dir), "missing scratch directory %s"%scratch_dir
self.dws_checkpoint_path = join(scratch_dir, 'checkpoints') # type: str
if not isdir(self.dws_checkpoint_path):
os.mkdir(self.dws_checkpoint_path)
self.checkpoint_filepath_template = join(self.dws_checkpoint_path, model_name+'_{epoch}')
super().__init__(filepath=self.checkpoint_filepath_template,
monitor=monitor, save_best_only=save_best_only, mode=mode,
save_freq=save_freq, save_weights_only=True,
verbose=tf_verbose)

def on_train_begin(self, logs:Optional[Dict]=None):
files_to_delete = [] # type: List[str]
files_to_delete.extend(glob.glob(join(self.dws_checkpoint_path,
self.dws_model_name+'_*[0-9].index')))
files_to_delete.extend(glob.glob(join(self.dws_checkpoint_path,
self.dws_model_name+'_*[0-9].data-*[0-9]-of-*[0-9]')))
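# e.g. for model 'fashion', these patterns match TensorFlow checkpoint
# shards such as fashion_3.index and fashion_3.data-00000-of-00001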
checkpoint_metadata_file = join(self.dws_checkpoint_path, 'checkpoint')
if exists(checkpoint_metadata_file):
files_to_delete.append(checkpoint_metadata_file)
for f in files_to_delete:
os.remove(f)
print("dws> Removed %d old checkpoint files for model %s ahead of training"%
(len(files_to_delete), self.dws_model_name))
return super().on_train_begin(logs)

def on_train_end(self, logs:Optional[Dict]=None):
checkpoint_metadata_file = join(self.dws_checkpoint_path, 'checkpoint')
assert exists(checkpoint_metadata_file), \
"Missing checkpoint metadata file %s"%checkpoint_metadata_file
# find the checkpoint that we want to save
with open(checkpoint_metadata_file, 'r') as f:
MODEL_CHECKPOINT_PATH=re.compile('^'+re.escape('model_checkpoint_path:')+r'\s+"('+
re.escape(self.dws_model_name+'_')+r'\d+)"$')
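# The TensorFlow checkpoint metadata file contains lines such as:
#   model_checkpoint_path: "fashion_10"
#   all_model_checkpoint_paths: "fashion_10"
# The regex above extracts the checkpoint basename from the first form.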
checkpoint_base = None
for line in f:
mo = MODEL_CHECKPOINT_PATH.match(line.rstrip())
if mo is not None:
checkpoint_base = mo.group(1)
break
assert checkpoint_base is not None,\
"Did not find model checkpoint path in %s"%checkpoint_metadata_file
copy_files=[] # type: List[str]
copy_files.append(join(self.dws_checkpoint_path, checkpoint_base+'.index'))
copy_files.extend(glob.glob(join(self.dws_checkpoint_path, checkpoint_base+'.data-*[0-9]-of-*[0-9]')))
copy_files.append(join(self.dws_checkpoint_path, 'checkpoint')) # also copy the checkpoint metadata file to make the checkpoint easy to load
for src_file in copy_files:
if self.results_subdir is not None:
dest_path = join(self.results_subdir, basename(src_file))
else:
dest_path = basename(src_file)
cast(FileResourceMixin, self.results_resource).upload_file(src_file,
dest_path)
if self.results_subdir is not None:
print("dws> Copied checkpoint %s to resource %s:%s"%
(checkpoint_base, self.results_resource.name,
self.results_subdir))
else:
print("dws> Copied checkpoint %s to resource %s"%
(checkpoint_base, self.results_resource.name))

return super().on_train_end(logs)


class CheckpointConfig(NamedTuple):
"""Configuration for checkpoints, to be passed as a parameter
to :func:`~add_lineage_to_keras_model_class`, instead of
directly instantiating :class:`~DwsModelCheckpoint`.
The checkpoints are initially written under the workspace's
scratch space. At the end of training, the best checkpoint is
copied to the results resource.
The configuration fields are:
* ``model_name`` - name of the model to use in checkpoint files
* ``monitor`` - metric to monitor - defaults to val_loss
* ``save_best_only`` - if True, a checkpoint is only kept when the
monitored metric improves on the best value seen so far.
* ``mode`` - how to determine whether a metric is the "best" - auto, min, or max
* ``save_freq`` - 'epoch' or an integer
"""
model_name: str
monitor: str = 'val_loss'
save_best_only: bool = False
mode: str = 'auto'
save_freq: Union[str, int] = 'epoch'


def add_lineage_to_keras_model_class(Cls:type,
input_resource:Optional[Union[str, ResourceRef]]=None,
results_resource:Optional[Union[str, ResourceRef]]=None,
- workspace_dir=None,
- verbose=False):
+ workspace_dir:Optional[str]=None,
+ checkpoint_config:Optional[CheckpointConfig]=None,
+ verbose:bool=False):
"""This function wraps a Keras model class with a subclass that overwrites
key methods to make calls to the data lineage API.
@@ -222,6 +367,8 @@ def add_lineage_to_keras_model_class(Cls:type,
if not specified, will try to infer from the workspace.
* ``workspace_dir`` -- Optional directory specifying the workspace. Usually can be
inferred from the current directory.
* ``checkpoint_config`` -- Optional instance of :class:`~CheckpointConfig`, which
is used to enable checkpointing in ``fit()`` and ``fit_generator()``.
* ``verbose`` -- If True, print extra debugging information.
The following methods are wrapped:
@@ -252,6 +399,17 @@ class WrappedModel(Cls): # type: ignore
def __init__(self,*args,**kwargs):
super().__init__(*args, **kwargs)
self._dws_state = _DwsModelState(workspace, input_resource, results_resource)
if checkpoint_config is not None:
self.checkpoint_cb = DwsModelCheckpoint(checkpoint_config.model_name,
monitor=checkpoint_config.monitor,
save_best_only=checkpoint_config.save_best_only,
mode=checkpoint_config.mode,
save_freq=checkpoint_config.save_freq,
results_resource=results_resource,
workspace_dir=workspace_dir,
verbose=verbose)
else:
self.checkpoint_cb = None
def compile(self, optimizer,
loss=None,
metrics=None,
@@ -290,6 +448,11 @@ def fit(self, x,y=None, **kwargs):
if y is not None:
_add_to_hash(y, hash_state)
api_resource.save_current_hash() # in case we evaluate in a separate process
if self.checkpoint_cb:
if 'callbacks' in kwargs:
kwargs['callbacks'].append(self.checkpoint_cb)
else:
kwargs['callbacks'] = [self.checkpoint_cb,]
return super().fit(x, y, **kwargs)

def fit_generator(self,
@@ -318,6 +481,11 @@ def fit_generator(self,
generator = _TfKerasSequenceWrapper(generator, hash_state)
else:
generator = _wrap_generator(generator, hash_state)
if self.checkpoint_cb:
if callbacks is not None:
callbacks.append(self.checkpoint_cb)
else:
callbacks = [self.checkpoint_cb,]
results = super().fit_generator(generator,
steps_per_epoch,
epochs,
@@ -395,3 +563,4 @@ def evaluate_generator(self,
if workspace.verbose:
print("dws>> Wrapped model class %s" % Cls.__name__)
return WrappedModel
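
To tie these pieces together, here is a minimal usage sketch based on the
docstrings above (the model layers, dataset variables, and the 'fashion'
model name are illustrative):

import tensorflow.keras as keras
from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class,\
    CheckpointConfig

# Wrapping the class with a CheckpointConfig enables the DwsModelCheckpoint
# callback automatically on fit() and fit_generator().
Sequential = add_lineage_to_keras_model_class(
    keras.Sequential,
    checkpoint_config=CheckpointConfig(model_name='fashion',
                                       monitor='val_loss',
                                       save_best_only=True))
model = Sequential([keras.layers.Flatten(input_shape=(28, 28)),
                    keras.layers.Dense(128, activation='relu'),
                    keras.layers.Dense(10, activation='softmax')])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)  # data assumed already loaded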

25 changes: 25 additions & 0 deletions dataworkspaces/resources/git_resource.py
@@ -6,6 +6,7 @@
import os
from os.path import realpath, basename, isdir, join, dirname, exists,\
abspath, expanduser, commonpath, isabs
import shutil
import stat
import click
import json
@@ -101,6 +102,29 @@ def does_subpath_exist(self, subpath:str, must_be_file:bool=False,
return does_subpath_exist(self.local_path, subpath, must_be_file,
must_be_directory)

def upload_file(self, src_local_path:str,
rel_dest_path:str) -> None:
"""Copy a local file to the specified path in the
resource. This may be a local copy or an upload, depending
on the resource implementation.
"""
abs_dest_path = join(self.local_path, rel_dest_path)
parent_dir = dirname(abs_dest_path)
if not exists(src_local_path):
raise ConfigurationError("Source file %s does not exist"%src_local_path)
if not exists(parent_dir):
os.makedirs(parent_dir)
shutil.copyfile(src_local_path, abs_dest_path)
rel_to_repo_path = get_subpath_from_absolute(self.repo_dir, abs_dest_path)
assert rel_to_repo_path is not None
call_subprocess([GIT_EXE_PATH, 'add', rel_to_repo_path],
cwd=self.repo_dir, verbose=self.workspace.verbose)
call_subprocess([GIT_EXE_PATH, 'commit',
'-m', "Added %s" % rel_to_repo_path],
cwd=self.repo_dir, verbose=self.workspace.verbose)
if self.workspace.verbose:
click.echo("%s: Copied file to %s" % (self.name, rel_dest_path))

def read_results_file(self, subpath:str) -> Union[JSONDict,JSONList]:
"""Read and parse json results data from the specified path
in the resource. If the path does not exist or is not a file
@@ -550,6 +574,7 @@ def add_results_file(self, data:Union[JSONDict,JSONList], rel_dest_path:str) ->
'-m', "Added %s" % rel_to_repo_path],
cwd=self.workspace_dir, verbose=self.workspace.verbose)


def push_precheck(self):
if not exists(self.local_path):
raise ConfigurationError("Missing directory %s for resource %s"%
15 changes: 15 additions & 0 deletions dataworkspaces/resources/local_file_resource.py
@@ -87,6 +87,21 @@ def add_results_file(self, data:Union[JSONDict,JSONList], rel_dest_path:str) ->
with open(abs_dest_path, 'w') as f:
json.dump(data, f, indent=2)


def upload_file(self, local_path:str,
rel_dest_path:str) -> None:
"""Copy a local file to the specified path in the
resource. This may be a local copy or an upload, depending
on the resource implementation.
"""
abs_dest_path = os.path.join(self.local_path, rel_dest_path)
parent_dir = os.path.dirname(abs_dest_path)
if not exists(local_path):
raise ConfigurationError("Source file %s does not exist." % local_path)
if not os.path.isdir(parent_dir):
os.makedirs(parent_dir)
shutil.copyfile(local_path, abs_dest_path)

def does_subpath_exist(self, subpath:str, must_be_file:bool=False,
must_be_directory:bool=False) -> bool:
return does_subpath_exist(self.local_path, subpath, must_be_file,
14 changes: 14 additions & 0 deletions dataworkspaces/resources/rclone_resource.py
@@ -96,6 +96,20 @@ def read_results_file(self, subpath:str) -> Union[JSONDict,JSONList]:
raise ConfigurationError("Parse error when reading %s in resource %s"
%(subpath, self.name)) from e

def upload_file(self, src_local_path:str,
rel_dest_path:str) -> None:
"""Copy a local file to the specified path in the
resource. This may be a local copy or an upload, depending
on the resource implementation.
"""
abs_dest_path = os.path.join(self.local_path, rel_dest_path)
parent_dir = os.path.dirname(abs_dest_path)
if not os.path.exists(src_local_path):
raise ConfigurationError("Source file %s does not exist." % src_local_path)
if not os.path.isdir(parent_dir):
os.makedirs(parent_dir)
shutil.copyfile(src_local_path, abs_dest_path)

def get_local_params(self) -> JSONDict:
return {} # TODO: local filepath can override global path

5 changes: 3 additions & 2 deletions dataworkspaces/utils/param_utils.py
@@ -200,7 +200,7 @@ def init_scratch_directory(scratch_dir:str, workspace_dir:str,
if rel_scratch_dir.startswith('./'):
scratch_dir_gitignore = rel_scratch_dir[1:]
else:
- scratch_dir_gitignore = '/'
+ scratch_dir_gitignore = '/' + rel_scratch_dir
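# e.g. rel_scratch_dir 'scratch' yields the root-anchored entry '/scratch'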
else:
local_params[LOCAL_SCRATCH_DIRECTORY] = abs_scratch_dir
return (abs_scratch_dir, scratch_dir_gitignore)
@@ -233,7 +233,8 @@ def get_scratch_directory(workspace_dir:str, global_params:Dict[str,Any],
in either, print a warning and return None.
"""
if SCRATCH_DIRECTORY in global_params:
- return join(workspace_dir, global_params[SCRATCH_DIRECTORY])
+ # normalize the path to remove any "." in the path
+ return abspath(join(workspace_dir, global_params[SCRATCH_DIRECTORY]))
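# e.g. join(workspace_dir, './scratch') gives WORKSPACE/./scratch, which
# abspath() normalizes to WORKSPACE/scratch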
elif LOCAL_SCRATCH_DIRECTORY in local_params:
return local_params[LOCAL_SCRATCH_DIRECTORY]
else:
9 changes: 9 additions & 0 deletions dataworkspaces/workspace.py
@@ -654,6 +654,15 @@ def add_results_file(self, data:Union[JSONDict,JSONList],
"""
pass

@abstractmethod
def upload_file(self, src_local_path:str,
rel_dest_path:str) -> None:
"""Copy a local file to the specified path in the
resource. This may be a local copy or an upload, depending
on the resource implementation.
"""
pass

@abstractmethod
def read_results_file(self, subpath:str) -> Union[JSONDict,JSONList]:
"""Read and parse json results data from the specified path
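
For example, the checkpoint callback above invokes this API roughly as
follows (resource and path names are illustrative):

    resource = workspace.get_resource('results')
    resource.upload_file('/home/me/ws/scratch/checkpoints/fashion_10.index',
                         'checkpoints/fashion_10.index')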
