Skip to content

Commit

Permalink
work on tensorflow wrapper, which now supports 1.x and 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jfischer committed Oct 31, 2019
1 parent c81909e commit 10d6fd4
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Integration with Tensorflow 1.x
"""Integration with Tensorflow 1.x and 2.0
This is an experimental API and subject to change.
Expand Down Expand Up @@ -100,8 +100,16 @@ def call(self, inputs):
from typing import Optional, Union, List
assert List

import tensorflow
if tensorflow.__version__.startswith('2.'): # type: ignore
USING_TENSORFLOW2=True
else:
USING_TENSORFLOW2=False
import tensorflow.keras.optimizers as optimizers
import tensorflow.losses as losses
if USING_TENSORFLOW2:
import tensorflow.keras.losses as losses
else:
import tensorflow.losses as losses

from dataworkspaces.workspace import find_and_load_workspace, ResourceRef
from dataworkspaces.kits.wrapper_utils import _DwsModelState, _add_to_hash
Expand All @@ -111,6 +119,7 @@ def call(self, inputs):
def add_lineage_to_keras_model_class(Cls:type,
input_resource:Optional[Union[str, ResourceRef]]=None,
results_resource:Optional[Union[str, ResourceRef]]=None,
workspace_dir=None,
verbose=False):
"""This function wraps a Keras model class with a subclass that overwrites
key methods to make calls to the data lineage API.
Expand All @@ -128,15 +137,15 @@ def add_lineage_to_keras_model_class(Cls:type,
"""
if hasattr(Cls, '_dws_model_wrap') and Cls._dws_model_wrap is True: # type: ignore
print("%s or a superclass is already wrapped" % Cls.__name__)
print("dws>> %s or a superclass is already wrapped" % Cls.__name__)
return Cls # already wrapped
workspace = find_and_load_workspace(batch=True, verbose=verbose)
workspace = find_and_load_workspace(batch=True, verbose=verbose,
uri_or_local_path=workspace_dir)

class WrappedModel(Cls): # type: ignore
_dws_model_wrap = True
def __init__(self,*args,**kwargs):
super().__init__(*args, **kwargs)
print("In wrapped init") # XXX
self._dws_state = _DwsModelState(workspace, input_resource, results_resource)
def compile(self, optimizer,
loss=None,
Expand All @@ -159,7 +168,6 @@ def compile(self, optimizer,
sample_weight_mode, weighted_metrics,
target_tensors, distribute, **kwargs)
def fit(self, x, y, **kwargs):
print("fit: in wrap of %s" % Cls.__name__) # XXX
if 'epochs' in kwargs:
self._dws_state.lineage.add_param('epochs', kwargs['epochs'])
else:
Expand Down Expand Up @@ -191,4 +199,7 @@ def evaluate(self, x, y, **kwargs):
self._dws_state.write_metrics_and_complete({n:v for (n, v) in
zip(self.metrics_names, results)})
return results
WrappedModel.__name__ = Cls.__name__ # this is to fake things out for the reporting
if workspace.verbose:
print("dws>> Wrapped model class %s" % Cls.__name__)
return WrappedModel
23 changes: 14 additions & 9 deletions dataworkspaces/kits/wrapper_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,36 @@ def _metric_obj_to_json(v):

def _add_to_hash(array_data, hash_state):
if isinstance(array_data, np.ndarray):
hash_state.update(array_data.data.tobytes())
hash_state.update(array_data.data)
elif (pandas is not None) and isinstance(array_data, pandas.DataFrame):
for c in array_data.columns:
hash_state.update(array_data[c].to_numpy(copy=False).data.to_bytes())
hash_state.update(array_data[c].to_numpy(copy=False).data)
elif (pandas is not None) and isinstance(array_data, pandas.Series):
hash_state.update(array_data.to_numpy(copy=False).data.to_bytes())
hash_state.update(array_data.to_numpy(copy=False).data)
else:
raise Exception("Unable to hash data type %s, data was: %s"%
(type(array_data), array_data))


def _find_resource(workspace:Workspace, role:str,
name_or_ref:Optional[Union[str, ResourceRef]]=None) -> ResourceRef:
resource_names = [n for n in workspace.get_resource_names()]
if isinstance(name_or_ref, str):
if (not name_or_ref.startswith('./')) and (not name_or_ref.startswith('/')) and \
(name_or_ref in workspace.get_resource_names()):
(name_or_ref in resource_names):
return ResourceRef(name_or_ref)
elif exists(name_or_ref):
return workspace.map_local_path_to_resource(name_or_ref,
expecting_a_code_resource=False)
else:
raise LineageError("Could not find a resource for '" + name_or_ref +
" in your workspace. Please create a resource"+
" using the dws add command or correct the name.")
"' with role '" + role +
"' in your workspace. Please create a resource"+
" using the 'dws add' command or correct the name. "+
"Currently defined resources are: " +
', '.join(["%s (role %s)" %
(n, workspace.get_resource_role(n))
for n in resource_names]) + '.')
elif isinstance(name_or_ref, ResourceRef):
workspace.validate_resource_name(name_or_ref.name, name_or_ref.subpath)
return name_or_ref
Expand Down Expand Up @@ -117,8 +123,6 @@ def __init__(self, workspace:Workspace,

def find_input_resources_and_return_if_api(self, data, target_data=None) \
-> Optional[ApiResource]:
print("default_input_resource: %s, input_resources=%s" % (self.default_input_resource,
self.lineage.step.input_resources)) # XXX
if hasattr(data, 'resource'):
ref = data.resource
else:
Expand All @@ -144,6 +148,7 @@ def find_input_resources_and_return_if_api(self, data, target_data=None) \

def write_metrics_and_complete(self, metrics):
    """Write the final metrics for this step and complete its lineage.

    ``metrics`` is first converted with ``_metric_obj_to_json``; the
    converted value is written through ``self.lineage.write_results()`` and
    the lineage is then completed. The metrics are echoed to stdout only
    when the workspace is in verbose mode.
    """
    metrics = _metric_obj_to_json(metrics)
    # Only the verbose-gated message is kept; the unconditional print was
    # debugging output and would repeat the same information.
    if self.workspace.verbose:
        print("dws>> Metrics: %s" % repr(metrics))
    self.lineage.write_results(metrics)
    self.lineage.complete()
3 changes: 2 additions & 1 deletion dataworkspaces/resources/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def save_current_hash(self,comment:Optional[str]=None) -> None:
hashfile = join(scratch, 'hashval.txt')
with open(hashfile, 'w') as f:
f.write(hashval)
print("wrote hashval of '%s' to %s'" % (hashval, hashfile)) # XXX
if self.workspace.verbose:
print("dws>> %s: wrote hashval of '%s' to %s'" % (self.name, hashval, hashfile))
commentfile = join(scratch, 'comment.txt')
if comment is not None:
with open(commentfile, 'w') as f:
Expand Down
1 change: 0 additions & 1 deletion dataworkspaces/utils/lineage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,6 @@ def make_step_lineage(instance:str, step_name:str, start_time:datetime.datetime,
"Step %s at %s"% (step_name, start_time),
for_code=False)
for ref in input_resource_refs] # List[ResourceCert]
print("input_certs: %s" % repr(input_certs)) # XXX
code_certs = [
lineage_store.get_or_create_cert(instance, ref,
"Step %s at %s"% (step_name, start_time),
Expand Down
7 changes: 4 additions & 3 deletions docs/kits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ Scikit-learn
:no-undoc-members:
:members: load_dataset_from_resource,train_and_predict_with_cv,Metrics,BinaryClassificationMetrics,MulticlassClassificationMetrics

TensorFlow 1.x
--------------
.. automodule:: dataworkspaces.kits.tensorflow1
TensorFlow
----------

.. automodule:: dataworkspaces.kits.tensorflow
:no-undoc-members:
:members: add_lineage_to_keras_model_class

2 changes: 1 addition & 1 deletion tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ DATAWORKSPACES:=$(shell cd ../dataworkspaces; pwd)
help:
@echo targets are: test clean mypy pyflakes check help install-rclone-deb

UNIT_TESTS=test_git_utils test_file_utils test_move_results test_snapshots test_push_pull test_local_files_resource test_hashtree test_lineage_utils test_git_fat_integration test_lineage test_jupyter_kit test_api
UNIT_TESTS=test_git_utils test_file_utils test_move_results test_snapshots test_push_pull test_local_files_resource test_hashtree test_lineage_utils test_git_fat_integration test_lineage test_jupyter_kit test_api test_wrapper_utils test_tensorflow

# Kit modules to type-check (presumably consumed by the mypy target -- confirm).
# The TensorFlow kit is imported as dataworkspaces.kits.tensorflow by the docs
# and tests, so the module file listed here is tensorflow.py.
MYPY_KITS=scikit_learn.py jupyter.py tensorflow.py

Expand Down
54 changes: 54 additions & 0 deletions tests/test_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
import sys
import os.path
from os.path import exists, join
import json

from utils_for_tests import SimpleCase, WS_DIR

try:
import tensorflow
TF_INSTALLED=True
except ImportError:
TF_INSTALLED=False

class TestTensorflowKit(SimpleCase):
    """End-to-end test of the Keras model wrapper against a scratch workspace."""

    @unittest.skipUnless(TF_INSTALLED, "Tensorflow not available")
    def test_wrapper(self):
        """This test follows the basic classification tutorial.

        A wrapped ``keras.Sequential`` model is trained and evaluated on
        Fashion-MNIST, and the loss and accuracy returned by ``evaluate()``
        are compared with the metrics found in ``results/results.json``
        inside the workspace.
        """
        import tensorflow.keras as keras
        self._setup_initial_repo(git_resources='results', api_resources='fashion-mnist-data')
        from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class
        # The tutorial replaces keras.Sequential in place. Restore the
        # original class afterwards so the wrapped version (bound to this
        # test's temporary workspace) does not leak into other tests that
        # run in the same process.
        self.addCleanup(setattr, keras, 'Sequential', keras.Sequential)
        keras.Sequential = add_lineage_to_keras_model_class(keras.Sequential,
                                                            input_resource='fashion-mnist-data',
                                                            verbose=True,
                                                            workspace_dir=WS_DIR)
        fashion_mnist = keras.datasets.fashion_mnist
        (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
        # scale pixel values into [0, 1] as the tutorial does
        train_images = train_images / 255.0
        test_images = test_images / 255.0
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        model.fit(train_images, train_labels, epochs=5)
        test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
        print("test accuracy: %s" % test_acc)
        results_file = join(WS_DIR, 'results/results.json')
        self.assertTrue(exists(results_file), "missing file %s" % results_file)
        with open(results_file, 'r') as f:
            data = json.load(f)
        self.assertAlmostEqual(test_acc, data['metrics']['accuracy'])
        self.assertAlmostEqual(test_loss, data['metrics']['loss'])


if __name__ == '__main__':
unittest.main()

50 changes: 50 additions & 0 deletions tests/test_wrapper_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

import unittest
import sys
import os.path
import hashlib

try:
import dataworkspaces
except ImportError:
sys.path.append(os.path.abspath(".."))

from dataworkspaces.kits.wrapper_utils import _add_to_hash

try:
import pandas
except ImportError:
pandas = None

try:
import numpy
except ImportError:
numpy = None


class TestAddToHash(unittest.TestCase):
    """Check that _add_to_hash() hashes the raw element bytes of its input.

    Each test compares the digest produced through _add_to_hash() with a
    SHA-1 computed directly over the ``tobytes()`` output of the same data,
    so a regression in the hashing helper fails the test instead of merely
    printing a different digest.
    """
    def setUp(self):
        self.hash_state = hashlib.sha1()

    def _expected_digest(self, arrays):
        """Return the SHA-1 hex digest of the given arrays' bytes, in order."""
        expected = hashlib.sha1()
        for arr in arrays:
            expected.update(arr.tobytes())
        return expected.hexdigest()

    @unittest.skipUnless(pandas is not None, 'Pandas not available')
    def test_pandas_df(self):
        df = pandas.DataFrame({'x1':[1,2,3,4,5],
                               'x2':[1.5,2.5,3.5,4.5,5.5],
                               'y':[1,0,0,1,1]})
        _add_to_hash(df, self.hash_state)
        # data frames are hashed one column at a time, in column order
        self.assertEqual(self.hash_state.hexdigest(),
                         self._expected_digest([df[c].to_numpy() for c in df.columns]))

    @unittest.skipUnless(pandas is not None, 'Pandas not available')
    def test_pandas_series(self):
        s = pandas.Series([1,0,0,1,1], name='y')
        _add_to_hash(s, self.hash_state)
        self.assertEqual(self.hash_state.hexdigest(),
                         self._expected_digest([s.to_numpy()]))

    @unittest.skipUnless(numpy is not None, "Numpy not available")
    def test_numpy(self):
        a = numpy.arange(45)
        _add_to_hash(a, self.hash_state)
        self.assertEqual(self.hash_state.hexdigest(),
                         self._expected_digest([a]))

if __name__ == '__main__':
unittest.main()
89 changes: 62 additions & 27 deletions tests/utils_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,7 @@
from dataworkspaces.utils.git_utils import GIT_EXE_PATH
from dataworkspaces.utils.subprocess_utils import find_exe

class BaseCase(unittest.TestCase):
"""utilities to set up an environment that can has two copies of a workspace
and a central bare git repo as the origin.
"""
def setUp(self):
if os.path.exists(TEMPDIR):
shutil.rmtree(TEMPDIR)
os.mkdir(TEMPDIR)
os.mkdir(WS_DIR)
self.dws=find_exe("dws", "Make sure you have enabled your python virtual environment")

def tearDown(self):
if os.path.exists(TEMPDIR):
shutil.rmtree(TEMPDIR)

class HelperMethods:
def _run_dws(self, dws_args, cwd=WS_DIR, env=None):
command = self.dws + ' --verbose --batch '+ ' '.join(dws_args)
print(command + (' [%s]' % cwd))
Expand All @@ -57,18 +43,9 @@ def _run_git(self, git_args, cwd=WS_DIR):
r = subprocess.run(args, cwd=cwd)
r.check_returncode()

def _setup_initial_repo(self, create_resources=None):
if create_resources is not None:
self._run_dws(['init', '--create-resources='+create_resources], cwd=WS_DIR)
else:
self._run_dws(['init'], cwd=WS_DIR)
self._run_git(['init', '--bare', 'workspace_origin.git'],
cwd=TEMPDIR)
self._run_git(['remote', 'add', 'origin', WS_ORIGIN], cwd=WS_DIR)
self._run_dws(['push'], cwd=WS_DIR)

def _clone_second_repo(self):
self._run_dws(['clone', WS_ORIGIN, 'workspace2'], cwd=TEMPDIR)
def _add_api_resource(self, name, role='source-data', cwd=WS_DIR):
    """Add an API resource called ``name`` with the given ``role`` by
    running ``dws add api-resource`` from the directory ``cwd``.
    """
    add_args = ['add', 'api-resource']
    add_args.extend(['--role', role])
    add_args.extend(['--name', name])
    self._run_dws(add_args, cwd=cwd)

def _assert_files_same(self, f1, f2):
self.assertTrue(exists(f1), "Missing file %s" % f1)
Expand All @@ -90,3 +67,61 @@ def _get_resource_set(self, workspace_dir):
names.add(obj['name'])
return names


class BaseCase(HelperMethods, unittest.TestCase):
    """Test fixture providing a workspace with a bare git repository as its
    origin, plus a helper to clone that origin into a second workspace.
    """
    def setUp(self):
        # always start from an empty temporary directory
        if os.path.exists(TEMPDIR):
            shutil.rmtree(TEMPDIR)
        for directory in (TEMPDIR, WS_DIR):
            os.mkdir(directory)
        self.dws = find_exe("dws",
                            "Make sure you have enabled your python virtual environment")

    def tearDown(self):
        if os.path.exists(TEMPDIR):
            shutil.rmtree(TEMPDIR)

    def _setup_initial_repo(self, create_resources=None):
        """Initialize the workspace, create the bare origin repository,
        register it as the ``origin`` remote and push the workspace to it.
        """
        init_args = ['init']
        if create_resources is not None:
            init_args.append('--create-resources='+create_resources)
        self._run_dws(init_args, cwd=WS_DIR)
        bare_init = ['init', '--bare', 'workspace_origin.git']
        self._run_git(bare_init, cwd=TEMPDIR)
        add_remote = ['remote', 'add', 'origin', WS_ORIGIN]
        self._run_git(add_remote, cwd=WS_DIR)
        self._run_dws(['push'], cwd=WS_DIR)

    def _clone_second_repo(self):
        """Clone the origin into ``workspace2`` under the temp directory."""
        clone_args = ['clone', WS_ORIGIN, 'workspace2']
        self._run_dws(clone_args, cwd=TEMPDIR)



class SimpleCase(HelperMethods, unittest.TestCase):
    """utilities to set up an environment that has a single workspace with
    no origin or remote. This is for tests that are not involved
    in syncing of workspaces.
    """
    def setUp(self):
        # always start from an empty temporary directory
        if os.path.exists(TEMPDIR):
            shutil.rmtree(TEMPDIR)
        os.mkdir(TEMPDIR)
        os.mkdir(WS_DIR)
        self.dws=find_exe("dws", "Make sure you have enabled your python virtual environment")

    def tearDown(self):
        if os.path.exists(TEMPDIR):
            shutil.rmtree(TEMPDIR)

    def _setup_initial_repo(self, git_resources=None, api_resources=None):
        """Initialize the workspace and optionally create resources.

        :param git_resources: value passed to ``dws init --create-resources=``;
            when None, a plain ``dws init`` is run.
        :param api_resources: comma-separated names of API resources to add
            after initialization. Surrounding whitespace is ignored and empty
            entries are skipped, so ``''`` or a trailing comma does not try
            to create a resource with an empty name.
        """
        if git_resources is not None:
            self._run_dws(['init', '--create-resources='+git_resources], cwd=WS_DIR)
        else:
            self._run_dws(['init'], cwd=WS_DIR)
        if api_resources is not None:
            for rname in api_resources.split(','):
                rname = rname.strip()
                if rname:
                    self._add_api_resource(rname, cwd=WS_DIR)


0 comments on commit 10d6fd4

Please sign in to comment.