Skip to content

Commit

Permalink
minor additions to tensorflow docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jfischer committed Oct 7, 2019
1 parent 57d1464 commit 632de8b
Showing 1 changed file with 18 additions and 38 deletions.
56 changes: 18 additions & 38 deletions dataworkspaces/kits/tensorflow1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
input data as it is passed to the model, and a *Results* resource to
keep the metrics. The only change we need to do to capture the lineage from
the model is to wrap the model's class, using
:func:`~add_lineage_to-keras_model_class`.
:func:`~add_lineage_to_keras_model_class`.
Here is the code::
Expand Down Expand Up @@ -70,14 +70,15 @@
@add_lineage_to_keras_model_class
class MyModel(keras.Model):
def __init__(self):
print("In MyModel init")
# The Tensorflow documentation tends to specify the class name
# when calling the superclass __init__ function. Don't do this --
# it breaks if you use class decorators!
#super(MyModel, self).__init__()
super().__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
print("Inputs: %s" % repr(inputs))
x1 = self.dense1(inputs)
return self.dense2(x1)
Expand All @@ -93,6 +94,8 @@ def call(self, inputs):
print('Test accuracy:', test_acc)
**API**
"""
import hashlib
from typing import Optional, Union
Expand All @@ -111,41 +114,6 @@ def call(self, inputs):
assert ApiResource # make pyflakes happy
from dataworkspaces.kits.jupyter import get_step_name_for_notebook

#from dataworkspaces.utils.patch_utils import patch_method
#def add_lineage_to_model_class(Cls):
# if hasattr(Cls, '_dws_model_wrap') and Cls._dws_model_wrap is True:
# print("%s is already wrapped" % Cls.__name__)
# return Cls # already wrapped
#def make_model__init__(original_method):
# def init(self, **kwargs):
# self.workspace = find_and_load_workspace(batch=True, verbose=False)
# self.hash_state = hashlib.sha1()
# original_method(self, **kwargs)
# return init
#patch_method(keras.Model, '__init__', make_model__init__)
#
#def make_model_fit(original_method):
# def fit(self, x, y, **kwargs):
# self.hash_state.update(x.data.tobytes())
# self.hash_state.update(y.data.tobytes())
# print("captured hash of training data: %s" % self.hash_state.hexdigest())
# return original_method(self, x, y, **kwargs)
# return fit
#patch_method(keras.Model, 'fit', make_model_fit)
#
#def make_model_evaluate(original_method):
# def evaluate(self, x, y, **kwargs):
# h = self.hash_state.copy()
# h.update(x.data.tobytes())
# h.update(y.data.tobytes())
# print("hash of input data is %s" % h.hexdigest())
# results = original_method(self, x, y, **kwargs)
# assert len(results)==len(self.metrics_names)
# metrics = {n:v for (n, v) in zip(self.metrics_names, results)}
# print("Metrics: %s" % metrics)
# return results
# return evaluate
#patch_method(keras.Model, 'evaluate', make_model_evaluate)

def _find_resource(workspace:Workspace, role:str,
name_or_ref:Optional[Union[str, ResourceRef]]=None) -> ResourceRef:
Expand Down Expand Up @@ -191,6 +159,18 @@ def add_lineage_to_keras_model_class(Cls:type,
results_resource:Optional[Union[str, ResourceRef]]=None):
"""This function wraps a Keras model class with a subclass that overwrites
key methods to make calls to the data lineage API.
The following methods are wrapped:
* :func:`~__init__` - loads the workspace and adds dws-specific class members
* :func:`~compile` - captures the ``optimizer`` and ``loss_function`` parameter values
* :func:`~fit` - captures the ``epochs`` and ``batch_size`` parameter values;
if input is an API resource, capture hash values of training data, otherwise capture
input resource name.
* :func:`~evaluate` - captures the ``batch_size`` paramerter value; if input is an
API resource, capture hash values of test data, otherwise capture input resource
name; capture metrics and write them to results resource.
"""
if hasattr(Cls, '_dws_model_wrap') and Cls._dws_model_wrap is True: # type: ignore
print("%s or a superclass is already wrapped" % Cls.__name__)
Expand Down

0 comments on commit 632de8b

Please sign in to comment.