-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
work on tensorflow wrapper, which now supports 1.x and 2.0
- Loading branch information
Showing
9 changed files
with
204 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import unittest | ||
import sys | ||
import os.path | ||
from os.path import exists, join | ||
import json | ||
|
||
from utils_for_tests import SimpleCase, WS_DIR | ||
|
||
try: | ||
import tensorflow | ||
TF_INSTALLED=True | ||
except ImportError: | ||
TF_INSTALLED=False | ||
|
||
class TestTensorflowKit(SimpleCase): | ||
|
||
@unittest.skipUnless(TF_INSTALLED, "Tensorflow not available") | ||
def test_wrapper(self): | ||
"""This test follows the basic classification tutorial. | ||
""" | ||
import tensorflow as tf | ||
import tensorflow.keras as keras | ||
self._setup_initial_repo(git_resources='results', api_resources='fashion-mnist-data') | ||
from dataworkspaces.kits.tensorflow import add_lineage_to_keras_model_class | ||
keras.Sequential = add_lineage_to_keras_model_class(keras.Sequential, | ||
input_resource='fashion-mnist-data', | ||
verbose=True, | ||
workspace_dir=WS_DIR) | ||
fashion_mnist = keras.datasets.fashion_mnist | ||
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() | ||
train_images = train_images / 255.0 | ||
test_images = test_images / 255.0 | ||
model = keras.Sequential([ | ||
keras.layers.Flatten(input_shape=(28, 28)), | ||
keras.layers.Dense(128, activation='relu'), | ||
keras.layers.Dense(10, activation='softmax') | ||
]) | ||
model.compile(optimizer='adam', | ||
loss='sparse_categorical_crossentropy', | ||
metrics=['accuracy']) | ||
model.fit(train_images, train_labels, epochs=5) | ||
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2) | ||
print("test accuracy: %s" % test_acc) | ||
results_file = join(WS_DIR, 'results/results.json') | ||
self.assertTrue(exists(results_file), "missing file %s" % results_file) | ||
with open(results_file, 'r') as f: | ||
data = json.load(f) | ||
self.assertAlmostEqual(test_acc, data['metrics']['accuracy']) | ||
self.assertAlmostEqual(test_loss, data['metrics']['loss']) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
|
||
import unittest | ||
import sys | ||
import os.path | ||
import hashlib | ||
|
||
try: | ||
import dataworkspaces | ||
except ImportError: | ||
sys.path.append(os.path.abspath("..")) | ||
|
||
from dataworkspaces.kits.wrapper_utils import _add_to_hash | ||
|
||
try: | ||
import pandas | ||
except ImportError: | ||
pandas = None | ||
|
||
try: | ||
import numpy | ||
except ImportError: | ||
numpy = None | ||
|
||
|
||
class TestAddToHash(unittest.TestCase): | ||
def setUp(self): | ||
self.hash_state = hashlib.sha1() | ||
|
||
@unittest.skipUnless(pandas is not None, 'Pandas not available') | ||
def test_pandas_df(self): | ||
df = pandas.DataFrame({'x1':[1,2,3,4,5], | ||
'x2':[1.5,2.5,3.5,4.5,5.5], | ||
'y':[1,0,0,1,1]}) | ||
_add_to_hash(df, self.hash_state) | ||
print(self.hash_state.hexdigest()) | ||
|
||
@unittest.skipUnless(pandas is not None, 'Pandas not available') | ||
def test_pandas_series(self): | ||
s = pandas.Series([1,0,0,1,1], name='y') | ||
_add_to_hash(s, self.hash_state) | ||
print(self.hash_state.hexdigest()) | ||
|
||
@unittest.skipUnless(numpy is not None, "Numpy not available") | ||
def test_numpy(self): | ||
a = numpy.arange(45) | ||
_add_to_hash(a, self.hash_state) | ||
print(self.hash_state.hexdigest()) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters