add generator apis for tensorflow model wrapper
jfischer committed Dec 1, 2019
1 parent cc3d821 commit 948fe91
Showing 3 changed files with 245 additions and 7 deletions.
147 changes: 140 additions & 7 deletions dataworkspaces/kits/tensorflow.py
@@ -132,6 +132,7 @@ def call(self, inputs):
else:
    USING_TENSORFLOW2=False
import tensorflow.keras.optimizers as optimizers
import tensorflow.keras.utils as kerasutils
if USING_TENSORFLOW2:
    import tensorflow.keras.losses as losses
else:
@@ -155,6 +156,52 @@ def _verify_eager_if_dataset(x, y, api_resource):
"supported with TensorFlow 1.x.")


def _wrap_generator(wrapped, hash_state):
    """Return a generator that hashes the (inputs, targets[, sample_weights])
    tuple yielded on each iteration.
    """
    def wrapper():
        for v in wrapped:
            if len(v)==2:
                (inputs, targets) = v
                sample_weights = None
            else:
                (inputs, targets, sample_weights) = v
            _add_to_hash(inputs, hash_state)
            _add_to_hash(targets, hash_state)
            if sample_weights is not None:
                _add_to_hash(sample_weights, hash_state)
            yield v
    return wrapper()

class _TfKerasSequenceWrapper(kerasutils.Sequence):
    def __init__(self, wrapped, hash_state):
        self.wrapped = wrapped
        self.hash_state = hash_state

    def __getitem__(self, idx):
        v = self.wrapped.__getitem__(idx)
        if len(v)==2:
            (inputs, targets) = v
            sample_weights = None
        else:
            (inputs, targets, sample_weights) = v
        _add_to_hash(inputs, self.hash_state)
        _add_to_hash(targets, self.hash_state)
        if sample_weights is not None:
            _add_to_hash(sample_weights, self.hash_state)
        return v

    def __len__(self):
        return self.wrapped.__len__()

    def __iter__(self):
        return _wrap_generator(self.wrapped, self.hash_state)

    def on_epoch_end(self):
        # delegate to the wrapped sequence (calling self.on_epoch_end() here would recurse forever)
        return self.wrapped.on_epoch_end()


def add_lineage_to_keras_model_class(Cls:type,
                                     input_resource:Optional[Union[str, ResourceRef]]=None,
                                     results_resource:Optional[Union[str, ResourceRef]]=None,
@@ -184,9 +231,15 @@ def add_lineage_to_keras_model_class(Cls:type,
    * :func:`~fit` - captures the ``epochs`` and ``batch_size`` parameter values;
      if input is an API resource, capture hash values of training data, otherwise capture
      input resource name.
    * :func:`~fit_generator` - captures the ``epochs`` and ``steps_per_epoch`` parameter
      values; if input is an API resource, wraps the generator and captures the hashes
      of returned values from the generator as it is iterated through.
    * :func:`~evaluate` - captures the ``batch_size`` parameter value; if input is an
      API resource, capture hash values of test data, otherwise capture input resource
      name; capture metrics and write them to results resource.
    * :func:`~evaluate_generator` - captures the ``steps`` parameter value; if input is
      an API resource, wraps the generator and captures the hashes of returned values
      from the generator as it is iterated through.
"""
if hasattr(Cls, '_dws_model_wrap') and Cls._dws_model_wrap is True: # type: ignore
    print("dws>> %s or a superclass is already wrapped" % Cls.__name__)
@@ -221,13 +274,13 @@ def compile(self, optimizer,
target_tensors, distribute, **kwargs)
def fit(self, x, y=None, **kwargs):
    if 'epochs' in kwargs:
        self._dws_state.lineage.add_param('fit.epochs', kwargs['epochs'])
    else:
        self._dws_state.lineage.add_param('fit.epochs', 1)
    if 'batch_size' in kwargs:
        self._dws_state.lineage.add_param('fit.batch_size', kwargs['batch_size'])
    else:
        self._dws_state.lineage.add_param('fit.batch_size', None)
    api_resource = self._dws_state.find_input_resources_and_return_if_api(x, y)
    if api_resource is not None:
        _verify_eager_if_dataset(x, y, api_resource)
@@ -236,13 +289,58 @@ def fit(self, x,y=None, **kwargs):
        _add_to_hash(x, hash_state)
        if y is not None:
            _add_to_hash(y, hash_state)
        api_resource.save_current_hash() # in case we evaluate in a separate process
    return super().fit(x, y, **kwargs)

def fit_generator(self,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  validation_freq=1,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
    self._dws_state.lineage.add_param('fit_generator.epochs', epochs)
    self._dws_state.lineage.add_param('fit_generator.steps_per_epoch', steps_per_epoch)
    api_resource = self._dws_state.find_input_resources_and_return_if_api(generator)
    if api_resource is not None:
        # wrap the generator to capture each entry as it is returned
        api_resource.init_hash_state()
        hash_state = api_resource.get_hash_state()
        if isinstance(generator, kerasutils.Sequence):
            generator = _TfKerasSequenceWrapper(generator, hash_state)
        else:
            generator = _wrap_generator(generator, hash_state)
    results = super().fit_generator(generator,
                                    steps_per_epoch,
                                    epochs,
                                    verbose,
                                    callbacks,
                                    validation_data,
                                    validation_steps,
                                    validation_freq,
                                    class_weight,
                                    max_queue_size,
                                    workers,
                                    use_multiprocessing,
                                    shuffle,
                                    initial_epoch)
    if api_resource is not None:
        api_resource.save_current_hash()
    return results

def evaluate(self, x, y=None, **kwargs):
    if 'batch_size' in kwargs:
        self._dws_state.lineage.add_param('evaluate.batch_size', kwargs['batch_size'])
    else:
        self._dws_state.lineage.add_param('evaluate.batch_size', None)
    api_resource = self._dws_state.find_input_resources_and_return_if_api(x, y)
    if api_resource is not None:
        _verify_eager_if_dataset(x, y, api_resource)
@@ -258,6 +356,41 @@ def evaluate(self, x, y=None, **kwargs):
    self._dws_state.write_metrics_and_complete({n: v for (n, v) in
                                                zip(self.metrics_names, results)})
    return results

def evaluate_generator(self,
                       generator,
                       steps=None,
                       callbacks=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
    self._dws_state.lineage.add_param('evaluate_generator.steps', steps)
    api_resource = self._dws_state.find_input_resources_and_return_if_api(generator)
    if api_resource is not None:
        # wrap the generator to capture each entry as it is returned
        api_resource.dup_hash_state()
        hash_state = api_resource.get_hash_state()
        if isinstance(generator, kerasutils.Sequence):
            generator = _TfKerasSequenceWrapper(generator, hash_state)
        else:
            generator = _wrap_generator(generator, hash_state)
    results = super().evaluate_generator(generator,
                                         steps,
                                         callbacks,
                                         max_queue_size,
                                         workers,
                                         use_multiprocessing,
                                         verbose)
    if api_resource is not None:
        api_resource.save_current_hash()
        api_resource.pop_hash_state()
    assert len(results)==len(self.metrics_names)
    self._dws_state.write_metrics_and_complete({n: v for (n, v) in
                                                zip(self.metrics_names, results)})
    return results


WrappedModel.__name__ = Cls.__name__ # this is to fake things out for the reporting
if workspace.verbose:
    print("dws>> Wrapped model class %s" % Cls.__name__)
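A usage sketch for the wrapped generator APIs above, mirroring the new tests below. The resource names ('fashion-mnist-data', 'results') and the train_gen/test_gen generators are illustrative assumptions, not part of the library:

import tensorflow.keras as keras
from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class

# Wrap the model class so fit_generator()/evaluate_generator() record lineage.
Sequential = add_lineage_to_keras_model_class(keras.Sequential,
                                              input_resource='fashion-mnist-data',
                                              results_resource='results')
model = Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# train_gen and test_gen are assumed generators (or keras Sequences) yielding
# (inputs, targets) batches; each batch is hashed for lineage as it is consumed.
model.fit_generator(train_gen, epochs=5, steps_per_epoch=2)
test_loss, test_acc = model.evaluate_generator(test_gen, steps=100)
# evaluate_generator() also writes the resulting metrics to the results resource.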
2 changes: 2 additions & 0 deletions dataworkspaces/kits/wrapper_utils.py
@@ -93,6 +93,8 @@ def _add_to_hash(array_data, hash_state):
    else:
        raise Exception("Tensor type %s is not in eager mode, cannot convert to numpy, value was: %s"%
                        (type(array_data), repr(array_data)))
elif isinstance(array_data, (np.uint8, np.int8, np.int32, np.int64)):
    # Serialize the scalar's value into a fixed 8-byte representation.
    # (bytes(int(x)) would instead create x zero-bytes and fail for negative values.)
    hash_state.update(int(array_data).to_bytes(8, 'little', signed=True))
else:
    raise Exception("Unable to hash data type %s, data was: %s"%
                    (type(array_data), array_data))
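The new branch lets _add_to_hash() fold scalar numpy integers (e.g. individual labels yielded one at a time by a generator) into the dataset hash. A minimal sketch of that path, assuming a hashlib-style hash state object:

import hashlib
import numpy as np

hash_state = hashlib.sha1()
label = np.uint8(7)   # e.g. a single class label from a generator batch
# Serialize the scalar's value deterministically before updating the hash.
hash_state.update(int(label).to_bytes(8, 'little', signed=True))
print(hash_state.hexdigest())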
103 changes: 103 additions & 0 deletions tests/test_tensorflow.py
@@ -4,6 +4,7 @@
from os.path import exists, join
import json
import functools
import inspect

from utils_for_tests import SimpleCase, WS_DIR

@@ -31,6 +32,13 @@

from dataworkspaces.kits.wrapper_utils import NotSupportedError

def generator_from_arrays(x, y):
    assert len(x)==len(y)
    # keras expects the same number of dimensions, so we reshape to add one more
    old_shape = x[0].shape
    new_shape = (1, old_shape[0], old_shape[1])
    for i in range(len(y)):
        yield (x[i].reshape(new_shape), y[i].reshape((1,1)))

class TestTensorflowKit(SimpleCase):

@@ -199,6 +207,101 @@ def normalize_numeric_data(data, mean, std):
        self.assertAlmostEqual(test_loss, data['metrics']['loss'])
        self._take_snapshot()

    @unittest.skipUnless(TF_INSTALLED, "Tensorflow not available")
    def test_wrapper_for_generators(self):
        """This test follows the basic classification tutorial, modified to use
        the fit_generator() and evaluate_generator() methods.
        """
        import tensorflow as tf
        import tensorflow.keras as keras
        self._setup_initial_repo(git_resources='results', api_resources='fashion-mnist-data')
        from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class
        keras.Sequential = add_lineage_to_keras_model_class(keras.Sequential,
                                                            input_resource='fashion-mnist-data',
                                                            verbose=True,
                                                            workspace_dir=WS_DIR)
        fashion_mnist = keras.datasets.fashion_mnist
        (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
        train_images = train_images / 255.0
        test_images = test_images / 255.0
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        g = generator_from_arrays(train_images, train_labels)
        self.assertTrue(inspect.isgenerator(g))
        model.fit_generator(g, epochs=5, steps_per_epoch=2)
        g2 = generator_from_arrays(test_images, test_labels)
        test_loss, test_acc = model.evaluate_generator(g2, steps=len(test_labels), verbose=2)
        print("test accuracy: %s" % test_acc)
        results_file = join(WS_DIR, 'results/results.json')
        self.assertTrue(exists(results_file), "missing file %s" % results_file)
        with open(results_file, 'r') as f:
            data = json.load(f)
        self.assertAlmostEqual(test_acc, data['metrics']['accuracy' if TF_VERSION==2 else 'acc'])
        self.assertAlmostEqual(test_loss, data['metrics']['loss'])
        self._take_snapshot()

    @unittest.skipUnless(TF_INSTALLED, "Tensorflow not available")
    def test_wrapper_for_keras_sequence(self):
        """This test follows the basic classification tutorial, modified to use
        a keras Sequence with the fit_generator() and evaluate_generator() methods.
        """
        import tensorflow as tf
        import tensorflow.keras as keras
        import tensorflow.keras.utils as kerasutils
        class KSequence(kerasutils.Sequence):
            def __init__(self, x, y):
                assert len(x)==len(y)
                self.x = x
                self.y = y
                old_shape = x[0].shape
                self.new_shape = (1, old_shape[0], old_shape[1])

            def __iter__(self):
                return generator_from_arrays(self.x, self.y)

            def __getitem__(self, idx):
                return (self.x[idx].reshape(self.new_shape), self.y[idx].reshape((1,1)))

            def __len__(self):
                return len(self.y)

        self._setup_initial_repo(git_resources='results', api_resources='fashion-mnist-data')
        from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class
        keras.Sequential = add_lineage_to_keras_model_class(keras.Sequential,
                                                            input_resource='fashion-mnist-data',
                                                            verbose=True,
                                                            workspace_dir=WS_DIR)
        fashion_mnist = keras.datasets.fashion_mnist
        (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
        train_images = train_images / 255.0
        test_images = test_images / 255.0
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        g = KSequence(train_images, train_labels)
        model.fit_generator(g, epochs=5, steps_per_epoch=2)
        g2 = KSequence(test_images, test_labels)
        test_loss, test_acc = model.evaluate_generator(g2, steps=len(test_labels), verbose=2)
        print("test accuracy: %s" % test_acc)
        results_file = join(WS_DIR, 'results/results.json')
        self.assertTrue(exists(results_file), "missing file %s" % results_file)
        with open(results_file, 'r') as f:
            data = json.load(f)
        self.assertAlmostEqual(test_acc, data['metrics']['accuracy' if TF_VERSION==2 else 'acc'])
        self.assertAlmostEqual(test_loss, data['metrics']['loss'])
        self._take_snapshot()

if __name__ == '__main__':
    unittest.main()
