Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
28a0dd9
commit f8e02c2
Showing
5 changed files
with
313 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import glob | ||
import os | ||
import tempfile | ||
from shutil import rmtree | ||
from zipfile import ZipFile, ZIP_DEFLATED | ||
|
||
from mongoengine import GridFSProxy | ||
|
||
from omegaml.backends import BaseModelBackend | ||
|
||
|
||
class TensorflowSavedModelPredictor(object): | ||
def __init__(self, model_dir): | ||
from tensorflow.contrib import predictor | ||
self.predict_fn = predictor.from_saved_model(model_dir) | ||
|
||
def predict(self, X): | ||
return self.predict_fn(X) | ||
|
||
|
||
class TensorflowEstimatorBackend(BaseModelBackend): | ||
KIND = 'tf.savedmodel' | ||
_model_ext = 'tfsm' | ||
|
||
@classmethod | ||
def supports(self, obj, name, **kwargs): | ||
import tensorflow as tf | ||
return isinstance(obj, tf.estimator.Estimator) | ||
|
||
def _package_savedmodel(self, export_base_dir, filename): | ||
fname = os.path.basename(filename) | ||
zipfname = os.path.join(self.model_store.tmppath, fname) | ||
export_base_dir = glob.glob(os.path.join(export_base_dir, '*'))[0] | ||
with ZipFile(zipfname, 'w', compression=ZIP_DEFLATED) as zipf: | ||
for part in glob.glob(os.path.join(export_base_dir, '**'), recursive=True): | ||
zipf.write(part, os.path.relpath(part, export_base_dir)) | ||
return zipfname | ||
|
||
def _extract_savedmodel(self, packagefname): | ||
lpath = tempfile.mkdtemp() | ||
fname = os.path.basename(packagefname) | ||
mklfname = os.path.join(lpath, fname) | ||
with ZipFile(packagefname) as zipf: | ||
zipf.extractall(lpath) | ||
model = TensorflowSavedModelPredictor(lpath) | ||
rmtree(lpath) | ||
return model | ||
|
||
def put_model(self, obj, name, attributes=None, serving_input_receiver_fn=None, strip_default_attrs=None): | ||
# adapted from https://www.tensorflow.org/guide/saved_model#perform_the_export | ||
export_dir_base = tempfile.mkdtemp() | ||
obj.export_savedmodel(export_dir_base, | ||
serving_input_receiver_fn=serving_input_receiver_fn, | ||
strip_default_attrs=strip_default_attrs) | ||
zipfname = self._package_savedmodel(export_dir_base, name) | ||
with open(zipfname, 'rb') as fzip: | ||
fileid = self.model_store.fs.put( | ||
fzip, filename=self.model_store._get_obj_store_key(name, self._model_ext)) | ||
gridfile = GridFSProxy(grid_id=fileid, | ||
db_alias='omega', | ||
collection_name=self.model_store.bucket) | ||
rmtree(export_dir_base) | ||
os.remove(zipfname) | ||
return self.model_store._make_metadata( | ||
name=name, | ||
prefix=self.model_store.prefix, | ||
bucket=self.model_store.bucket, | ||
kind=self.KIND, | ||
attributes=attributes, | ||
gridfile=gridfile).save() | ||
|
||
def get_model(self, name, version=-1): | ||
filename = self.model_store._get_obj_store_key(name, self._model_ext) | ||
packagefname = os.path.join(self.model_store.tmppath, name) | ||
outf = self.model_store.fs.get_version(filename, version=version) | ||
with open(packagefname, 'wb') as zipf: | ||
zipf.write(outf.read()) | ||
model = self._extract_savedmodel(packagefname) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from time import sleep | ||
from urllib.parse import quote | ||
|
||
from behave import when, then | ||
<<<<<<< 0892f44208812268fc3dec7d81fae03b1cc450fb | ||
from selenium.webdriver.common.keys import Keys | ||
|
||
ACTIVATE_CELL = Keys.ESCAPE, Keys.ENTER | ||
EXEC_CELL = Keys.SHIFT, Keys.ENTER | ||
ADD_CELL_BELOW = Keys.ESCAPE, 'b' | ||
|
||
|
||
class Notebook: | ||
""" | ||
A simple driver for the notebook | ||
""" | ||
|
||
def __init__(self, browser): | ||
self.browser = browser | ||
try: | ||
alert = browser.get_alert() | ||
except: | ||
pass | ||
else: | ||
alert.accept() | ||
|
||
@property | ||
def body(self): | ||
return self.browser.find_by_css('body').first | ||
|
||
@property | ||
def jupyter_home(self): | ||
br = self.browser | ||
br.windows.current = br.windows[0] | ||
return self | ||
|
||
@property | ||
def last_notebook(self): | ||
br = self.browser | ||
br.windows.current = br.windows[-1] | ||
return self | ||
|
||
def login(self): | ||
br = self.browser | ||
assert br.is_element_present_by_id('ipython-main-app', wait_time=2) | ||
br.find_by_id('password_input').fill('omegamlisfun') | ||
br.find_by_id('login_submit').click() | ||
# check that there is actually a connection | ||
assert not br.is_text_present('Server error: Traceback', wait_time=2) | ||
assert not br.is_text_present('Connection refuse', wait_time=2) | ||
|
||
def create_folder(self): | ||
""" | ||
create a folder | ||
""" | ||
br = self.browser | ||
self.jupyter_home | ||
br.find_by_id('new-dropdown-button').click() | ||
br.find_by_text('Folder').click() | ||
sleep(2) | ||
|
||
def create_notebook(self, folder=None): | ||
""" | ||
create a new notebook | ||
""" | ||
br = self.browser | ||
self.jupyter_home | ||
br.find_by_id('new-dropdown-button').click() | ||
br.find_by_text('Python 3').click() | ||
sleep(2) | ||
self.last_notebook | ||
return self | ||
|
||
def open_folder(self, folder=None): | ||
br = self.browser | ||
folder = quote(folder.encode('utf-8')) | ||
item = br.find_link_by_href('/tree/{folder}'.format(**locals()))[0] | ||
item.click() | ||
return self | ||
|
||
def _clean_code(self, code): | ||
return tuple('\n'.join(line.strip() for line in code.split('\n'))) | ||
|
||
def current_cell_exec(self, code): | ||
self.body.type(ACTIVATE_CELL + self._clean_code(code) + EXEC_CELL) | ||
|
||
def new_cell_exec(self, code): | ||
self.body.type(ADD_CELL_BELOW + ACTIVATE_CELL + self._clean_code(code) + EXEC_CELL) | ||
|
||
def current_cell_output(self): | ||
return self.body.find_by_css('.output_subarea pre')[-1].text | ||
======= | ||
|
||
from omegaml.tests.features.util import Notebook | ||
>>>>>>> add more tests | ||
|
||
|
||
@when(u'we open jupyter') | ||
def open_jupyter(ctx): | ||
br = ctx.browser | ||
br.visit('http://localhost:8888') | ||
nb = Notebook(br) | ||
login_required = br.is_text_present('Password', wait_time=2) | ||
login_required |= br.is_text_present('token', wait_time=2) | ||
if login_required: | ||
nb.login() | ||
nb.jupyter_home | ||
|
||
@when(u'we create a notebook') | ||
def step_impl(ctx): | ||
br = ctx.browser | ||
nb = Notebook(br) | ||
nb.create_notebook() | ||
# test code execution | ||
code = """ | ||
print('hello') | ||
""".strip() | ||
nb.current_cell_exec(code) | ||
sleep(1) | ||
assert nb.current_cell_output() == 'hello' | ||
|
||
|
||
@when(u'we create a folder') | ||
def create_folder(ctx): | ||
br = ctx.browser | ||
nb = Notebook(br) | ||
nb.create_folder() | ||
nb.open_folder('Untitled Folder') | ||
|
||
|
||
@then(u'we can list datasets in omegaml') | ||
def step_impl(ctx): | ||
# test omegaml functionality | ||
br = ctx.browser | ||
nb = Notebook(br) | ||
code = """ | ||
import omegaml as om | ||
om.datasets.put(['sample'], 'sample', append=False) | ||
om.datasets.list('sample') | ||
""".strip() | ||
nb.new_cell_exec(code) | ||
sleep(3) | ||
assert nb.current_cell_output() == "['sample']" | ||
|
||
@then(u'we can add a notebook in the folder') | ||
def step_impl(ctx): | ||
br = ctx.browser | ||
br.visit('http://localhost:8888') | ||
nb = Notebook(br) | ||
nb.jupyter_home | ||
nb.open_folder('Untitled Folder') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from unittest import TestCase | ||
|
||
from omegaml import Omega | ||
from omegaml.backends.tfestimator import TensorflowEstimatorBackend, TensorflowSavedModelPredictor | ||
|
||
|
||
class TensorflowEstimatorBackendTests(TestCase): | ||
def setUp(self): | ||
self.om = Omega() | ||
self.om.models.register_backend(TensorflowEstimatorBackend.KIND, TensorflowEstimatorBackend) | ||
|
||
def _build_model(self): | ||
# build a dummy model for testing. does not need to make sense | ||
import tensorflow as tf | ||
keras = tf.keras | ||
Sequential = keras.models.Sequential | ||
Dense = keras.layers.Dense | ||
Dropout = keras.layers.Dropout | ||
SGD = keras.optimizers.SGD | ||
|
||
# Generate dummy data | ||
import numpy as np | ||
x_train = np.random.random((1000, 20)) | ||
y_train = keras.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10) | ||
x_test = np.random.random((100, 20)) | ||
y_test = keras.utils.to_categorical(np.random.randint(10, size=(100, 1)), num_classes=10) | ||
|
||
model = Sequential() | ||
# Dense(64) is a fully-connected layer with 64 hidden units. | ||
# in the first layer, you must specify the expected input data shape: | ||
# here, 20-dimensional vectors. | ||
model.add(Dense(64, activation='relu', input_dim=20)) | ||
model.add(Dropout(0.5)) | ||
model.add(Dense(64, activation='relu')) | ||
model.add(Dropout(0.5)) | ||
model.add(Dense(10, activation='softmax')) | ||
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) | ||
model.compile(loss='categorical_crossentropy', | ||
optimizer=sgd, | ||
metrics=['accuracy']) | ||
|
||
# https://www.tensorflow.org/guide/estimators | ||
est_model = tf.keras.estimator.model_to_estimator(keras_model=model) | ||
train_input_fn = tf.estimator.inputs.numpy_input_fn( | ||
x={"dense_input": x_train}, | ||
y=y_train, | ||
num_epochs=1, | ||
shuffle=False) | ||
|
||
est_model.train(train_input_fn) | ||
return est_model | ||
|
||
def test_save_load(self): | ||
import tensorflow as tf | ||
import numpy as np | ||
om = self.om | ||
model = self._build_model() | ||
|
||
# https://www.tensorflow.org/guide/saved_model#prepare_serving_inputs | ||
default_batch_size = 1 | ||
feature_spec = {'dense_input': tf.FixedLenFeature(dtype=np.int64, shape=(1,))} | ||
|
||
def serving_input_receiver_fn(): | ||
"""An input receiver that expects a serialized tf.Example.""" | ||
serialized_tf_example = tf.placeholder(dtype=tf.string, | ||
shape=[default_batch_size], | ||
name='input_example_tensor') | ||
receiver_tensors = {'examples': serialized_tf_example} | ||
features = tf.parse_example(serialized_tf_example, feature_spec) | ||
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors) | ||
|
||
om.models.put(model, 'estimator-model', | ||
serving_input_receiver_fn=serving_input_receiver_fn) | ||
self.assertIn('estimator-model', om.models.list()) | ||
model_ = om.models.get('estimator-model') | ||
self.assertIsInstance(model_, TensorflowSavedModelPredictor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters