Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing TensorFlow 2.0 Compatibility #1872

Merged
merged 23 commits into from Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
eb4dd4e
first prototype
juntai-zheng Sep 24, 2019
30f66c9
added searching for default session in load_model, modified and added…
juntai-zheng Sep 24, 2019
004b74e
quick lint
juntai-zheng Sep 24, 2019
e1c8319
changed ==2.0.0 to >=2.0.0
juntai-zheng Sep 24, 2019
448ee1c
removed duplicate code in load_model
juntai-zheng Sep 24, 2019
e4e82e9
removing examples for later PR
juntai-zheng Sep 25, 2019
32d6d47
doc changes
juntai-zheng Sep 25, 2019
d4aa68a
refactored TFWrapper into two classes for the different TF versions
juntai-zheng Sep 25, 2019
aa74201
lint for travis
juntai-zheng Sep 25, 2019
c375ba7
another lint
juntai-zheng Sep 25, 2019
8aa2017
travis bash fix
juntai-zheng Sep 25, 2019
94e9d81
added python specific test fixes for generator function
juntai-zheng Sep 26, 2019
bf86957
added pylint ignore comment
juntai-zheng Sep 26, 2019
842d692
Change ValueError to MLflowException and update error message
juntai-zheng Sep 26, 2019
ea87194
fixed long line
juntai-zheng Sep 26, 2019
cea5161
deduplicated _load_tensorflow_saved_model code and fixed appropriate …
juntai-zheng Sep 26, 2019
a53d08d
Update examples/tensorflow/train_predict.py whitespace
juntai-zheng Sep 26, 2019
12242e8
Minor wording in docs of mlflow/tensorflow.py
juntai-zheng Sep 26, 2019
85a89db
refactoring
juntai-zheng Sep 27, 2019
cb1fa8d
ran manual test for loading 1.X model in 2.0, fixed creating tensor o…
juntai-zheng Sep 27, 2019
94c6f6f
Update tensorflow.py
smurching Sep 27, 2019
8818849
added docs under log_model for pyfunc
juntai-zheng Sep 27, 2019
cbb3c6c
"pulling"
juntai-zheng Sep 27, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
168 changes: 124 additions & 44 deletions mlflow/tensorflow.py
Expand Up @@ -26,6 +26,7 @@
import mlflow
import tensorflow
import mlflow.keras
from distutils.version import LooseVersion
from tensorflow.keras.callbacks import Callback, TensorBoard # pylint: disable=import-error
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
Expand Down Expand Up @@ -186,20 +187,25 @@ def _validate_saved_model(tf_saved_model_dir, tf_meta_graph_tags, tf_signature_d
Validate the TensorFlow SavedModel by attempting to load it in a new TensorFlow graph.
If the loading process fails, any exceptions thrown by TensorFlow are propagated.
"""
validation_tf_graph = tensorflow.Graph()
validation_tf_sess = tensorflow.Session(graph=validation_tf_graph)
with validation_tf_graph.as_default():
if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
validation_tf_graph = tensorflow.Graph()
validation_tf_sess = tensorflow.Session(graph=validation_tf_graph)
with validation_tf_graph.as_default():
_load_tensorflow_saved_model(tf_saved_model_dir=tf_saved_model_dir,
tf_sess=validation_tf_sess,
tf_meta_graph_tags=tf_meta_graph_tags,
tf_signature_def_key=tf_signature_def_key)
else:
_load_tensorflow_saved_model(tf_saved_model_dir=tf_saved_model_dir,
tf_sess=validation_tf_sess,
tf_meta_graph_tags=tf_meta_graph_tags,
tf_signature_def_key=tf_signature_def_key)


def load_model(model_uri, tf_sess):
def load_model(model_uri, tf_sess=None):
"""
Load an MLflow model that contains the TensorFlow flavor from the specified path.

*This method must be called within a TensorFlow graph context.*
*With TensorFlow version <2.0.0, this method must be called within a TensorFlow graph context.*

:param model_uri: The location, in URI format, of the MLflow model. For example:

Expand All @@ -212,10 +218,17 @@ def load_model(model_uri, tf_sess):
`Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
artifact-locations>`_.

:param tf_sess: The TensorFlow session in which to load the model.
:return: A TensorFlow signature definition of type:

:param tf_sess: The TensorFlow session in which to load the model. If using TensorFlow
version >= 2.0.0, this argument is ignored. If using TensorFlow <2.0.0, if no
session is passed to this function, MLflow will attempt to load the model using
the default TensorFlow session. If no default session is available, then the
function raises an exception.
:return: For TensorFlow < 2.0.0, a TensorFlow signature definition of type:
``tensorflow.core.protobuf.meta_graph_pb2.SignatureDef``. This defines the input and
output tensors for model inference.
For TensorFlow >= 2.0.0, A callable graph (tf.function) that takes inputs and
returns inferences.

>>> import mlflow.tensorflow
>>> import tensorflow as tf
Expand All @@ -229,22 +242,38 @@ def load_model(model_uri, tf_sess):
>>> output_tensors = [tf_graph.get_tensor_by_name(output_signature.name)
>>> for _, output_signature in signature_def.outputs.items()]
"""

if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
juntai-zheng marked this conversation as resolved.
Show resolved Hide resolved
if not tf_sess:
tf_sess = tensorflow.get_default_session()
if not tf_sess:
raise MlflowException("No TensorFlow session found while calling load_model()." +
"You can set the default Tensorflow session before calling" +
" load_model via `session.as_default()`, or directly pass " +
"a session in which to load the model via the tf_sess " +
"argument.")

else:
if tf_sess:
warnings.warn("A TensorFlow session was passed into load_model, but the " +
"currently used version is TF 2.0 where sessions are deprecated. " +
"The tf_sess argument will be ignored.", FutureWarning)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key =\
_get_and_parse_flavor_configuration(model_path=local_model_path)
return _load_tensorflow_saved_model(tf_saved_model_dir=tf_saved_model_dir, tf_sess=tf_sess,
return _load_tensorflow_saved_model(tf_saved_model_dir=tf_saved_model_dir,
tf_meta_graph_tags=tf_meta_graph_tags,
tf_signature_def_key=tf_signature_def_key)
tf_signature_def_key=tf_signature_def_key,
tf_sess=tf_sess)


def _load_tensorflow_saved_model(tf_saved_model_dir, tf_sess, tf_meta_graph_tags,
tf_signature_def_key):
def _load_tensorflow_saved_model(tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key,
tf_sess=None):
"""
Load a specified TensorFlow model consisting of a TensorFlow metagraph and signature definition
from a serialized TensorFlow ``SavedModel`` collection.

:param tf_saved_model_dir: The local filesystem path or run-relative artifact path to the model.
:param tf_sess: The TensorFlow session in which to load the metagraph.
:param tf_meta_graph_tags: A list of tags identifying the model's metagraph within the
serialized ``SavedModel`` object. For more information, see the
``tags`` parameter of the `tf.saved_model.builder.SavedModelBuilder
Expand All @@ -255,17 +284,30 @@ def _load_tensorflow_saved_model(tf_saved_model_dir, tf_sess, tf_meta_graph_tags
signature definition mapping. For more information, see the
``signature_def_map`` parameter of the
``tf.saved_model.builder.SavedModelBuilder`` method.
:return: A TensorFlow signature definition of type:
:param tf_sess: The TensorFlow session in which to load the metagraph.
Required in TensorFlow versions < 2.0.0. Unused in TensorFlow versions >= 2.0.0
:return: For TensorFlow versions < 2.0.0:
A TensorFlow signature definition of type:
``tensorflow.core.protobuf.meta_graph_pb2.SignatureDef``. This defines input and
output tensors within the specified metagraph for inference.
For TensorFlow versions >= 2.0.0:
A callable graph (tensorflow.function) that takes inputs and returns inferences.
"""
meta_graph_def = tensorflow.saved_model.loader.load(
if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
loaded = tensorflow.saved_model.loader.load(
sess=tf_sess,
tags=tf_meta_graph_tags,
export_dir=tf_saved_model_dir)
if tf_signature_def_key not in meta_graph_def.signature_def:
raise MlflowException("Could not find signature def key %s" % tf_signature_def_key)
return meta_graph_def.signature_def[tf_signature_def_key]
loaded_sig = loaded.signature_def
else:
loaded = tensorflow.saved_model.load( # pylint: disable=no-value-for-parameter
tags=tf_meta_graph_tags,
export_dir=tf_saved_model_dir)
loaded_sig = loaded.signatures
if tf_signature_def_key not in loaded_sig:
raise MlflowException("Could not find signature def key %s. Available keys are: %s"
% (tf_signature_def_key, list(loaded_sig.keys())))
return loaded_sig[tf_signature_def_key]


def _get_and_parse_flavor_configuration(model_path):
Expand Down Expand Up @@ -298,21 +340,26 @@ def _load_pyfunc(path):
"""
tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key =\
_get_and_parse_flavor_configuration(model_path=path)

tf_graph = tensorflow.Graph()
tf_sess = tensorflow.Session(graph=tf_graph)
with tf_graph.as_default():
signature_def = _load_tensorflow_saved_model(
tf_saved_model_dir=tf_saved_model_dir, tf_sess=tf_sess,
tf_meta_graph_tags=tf_meta_graph_tags, tf_signature_def_key=tf_signature_def_key)

return _TFWrapper(tf_sess=tf_sess, tf_graph=tf_graph, signature_def=signature_def)
if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
tf_graph = tensorflow.Graph()
tf_sess = tensorflow.Session(graph=tf_graph)
with tf_graph.as_default():
signature_def = _load_tensorflow_saved_model(
tf_saved_model_dir=tf_saved_model_dir, tf_sess=tf_sess,
tf_meta_graph_tags=tf_meta_graph_tags, tf_signature_def_key=tf_signature_def_key)

return _TFWrapper(tf_sess=tf_sess, tf_graph=tf_graph, signature_def=signature_def)
else:
loaded_model = tensorflow.saved_model.load( # pylint: disable=no-value-for-parameter
export_dir=tf_saved_model_dir,
tags=tf_meta_graph_tags)
return _TF2Wrapper(infer=loaded_model.signatures[tf_signature_def_key])


class _TFWrapper(object):
"""
Wrapper class that exposes a TensorFlow model for inference via a ``predict`` function such that
``predict(data: pandas.DataFrame) -> pandas.DataFrame``.
``predict(data: pandas.DataFrame) -> pandas.DataFrame``. For TensorFlow versions < 2.0.0.
"""
def __init__(self, tf_sess, tf_graph, signature_def):
"""
Expand All @@ -323,31 +370,67 @@ def __init__(self, tf_sess, tf_graph, signature_def):
"""
self.tf_sess = tf_sess
self.tf_graph = tf_graph
# We assume that input keys in the signature definition correspond to input DataFrame column
# names
# We assume that input keys in the signature definition correspond to
# input DataFrame column names
self.input_tensor_mapping = {
tensor_column_name: tf_graph.get_tensor_by_name(tensor_info.name)
for tensor_column_name, tensor_info in signature_def.inputs.items()
tensor_column_name: tf_graph.get_tensor_by_name(tensor_info.name)
for tensor_column_name, tensor_info in signature_def.inputs.items()
}
# We assume that output keys in the signature definition correspond to output DataFrame
# column names
# We assume that output keys in the signature definition correspond to
# output DataFrame column names
self.output_tensors = {
sigdef_output: tf_graph.get_tensor_by_name(tnsr_info.name)
for sigdef_output, tnsr_info in signature_def.outputs.items()
sigdef_output: tf_graph.get_tensor_by_name(tnsr_info.name)
for sigdef_output, tnsr_info in signature_def.outputs.items()
}

def predict(self, df):
with self.tf_graph.as_default():
# Build the feed dict, mapping input tensors to DataFrame column values.
feed_dict = {
self.input_tensor_mapping[tensor_column_name]: df[tensor_column_name].values
for tensor_column_name in self.input_tensor_mapping.keys()
self.input_tensor_mapping[tensor_column_name]: df[tensor_column_name].values
for tensor_column_name in self.input_tensor_mapping.keys()
}
raw_preds = self.tf_sess.run(self.output_tensors, feed_dict=feed_dict)
pred_dict = {column_name: values.ravel() for column_name, values in raw_preds.items()}
pred_dict = {column_name: values.ravel() for
column_name, values in raw_preds.items()}
return pandas.DataFrame(data=pred_dict)


class _TF2Wrapper(object):
    """
    Wrapper class that exposes a TensorFlow model for inference via a ``predict`` function such that
    ``predict(data: pandas.DataFrame) -> pandas.DataFrame``. For TensorFlow versions >= 2.0.0.
    """
    def __init__(self, infer):
        """
        :param infer: Tensorflow function returned by a saved model that is used for inference.
        """
        self.infer = infer

    @staticmethod
    def _to_column(values):
        """
        Convert one output array into data usable as a single DataFrame column.

        :param values: A numpy array holding one model output.
        :return: A flat numpy array when ``values`` has fewer than two dimensions or when
                 every row has length 1; otherwise a (nested) Python list with one entry
                 per row.
        """
        # Rows of a 0-d or 1-d array are scalars, which have no ``len()``; such outputs
        # already hold one value per row, so only flattening is needed.
        if values.ndim < 2:
            return values.ravel()
        if all(len(element) == 1 for element in values):
            return values.ravel()
        return values.tolist()

    def predict(self, df):
        """
        Run inference on the columns of ``df`` and return the outputs as a DataFrame.

        :param df: A pandas DataFrame; each column is passed to ``infer`` as a keyword
                   argument named after the column.
        :return: A pandas DataFrame with one column per key of the output returned by ``infer``.
        """
        feed_dict = {}
        for df_col_name in list(df):
            # If there are multiple columns with the same name, selecting the shared name
            # from the DataFrame will result in another DataFrame containing the columns
            # with the shared name. TensorFlow cannot make eager tensors out of pandas
            # DataFrames, so we convert the DataFrame to a numpy array here.
            val = df[df_col_name]
            if isinstance(val, pandas.DataFrame):
                val = val.values
            feed_dict[df_col_name] = tensorflow.constant(val)
        raw_preds = self.infer(**feed_dict)
        pred_dict = {
            col_name: self._to_column(raw_preds[col_name].numpy())
            for col_name in raw_preds.keys()
        }
        return pandas.DataFrame.from_dict(data=pred_dict)


class __MLflowTfKerasCallback(Callback):
"""
Callback for auto-logging parameters (we rely on TensorBoard for metrics).
Expand Down Expand Up @@ -485,12 +568,9 @@ def autolog(every_n_iter=100):
global _LOG_EVERY_N_STEPS
_LOG_EVERY_N_STEPS = every_n_iter

from distutils.version import StrictVersion

if StrictVersion(tensorflow.__version__) < StrictVersion('1.12') \
or StrictVersion(tensorflow.__version__) >= StrictVersion('2.0'):
if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
warnings.warn("Could not log to MLflow. Only TensorFlow versions" +
"1.12 <= v < 2.0.0 are supported.")
"1.12 <= v <= 2.0.0 are supported.")
return

try:
Expand Down
98 changes: 98 additions & 0 deletions tests/tensorflow/iris_data_utils.py
@@ -0,0 +1,98 @@
# From https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py
# This file is the example used by TensorFlow to get users started. This code is used for testing.
import pandas as pd
import tensorflow as tf

# Locations of the iris training and test CSV files.
TRAIN_URL, TEST_URL = (
    "http://download.tensorflow.org/data/iris_training.csv",
    "http://download.tensorflow.org/data/iris_test.csv",
)

# Column names applied to the CSV files, in file order.
CSV_COLUMN_NAMES = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth',
    'Species',
]

# Iris species names.
SPECIES = [
    'Setosa',
    'Versicolor',
    'Virginica',
]


def maybe_download():
    """Fetch the iris training and test CSV files via ``tf.keras.utils.get_file``.

    :return: A ``(train_path, test_path)`` tuple holding the value returned by
             ``tf.keras.utils.get_file`` for ``TRAIN_URL`` and ``TEST_URL`` respectively.
    """
    def _fetch(url):
        # The file name handed to get_file is the last path segment of the URL.
        file_name = url.split('/')[-1]
        return tf.keras.utils.get_file(file_name, url)

    return _fetch(TRAIN_URL), _fetch(TEST_URL)


def load_data(y_name='Species'):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y).

    :param y_name: Name of the column to split off as the label.
    :return: Two ``(features, labels)`` pairs, training first and test second. Each
             ``features`` is the DataFrame read from the CSV with the ``y_name`` column
             removed, and ``labels`` is that removed column.
    """
    def _read_split(csv_path):
        # header=0 together with names= replaces the file's own header row
        # with CSV_COLUMN_NAMES.
        frame = pd.read_csv(csv_path, names=CSV_COLUMN_NAMES, header=0)
        # pop() removes the label column from the frame in place and returns it.
        labels = frame.pop(y_name)
        return frame, labels

    train_path, test_path = maybe_download()
    return _read_split(train_path), _read_split(test_path)


def train_input_fn(features, labels, batch_size):
    """An input function for training.

    :param features: Mapping-like collection of feature columns; converted with ``dict()``.
    :param labels: Labels paired with the features.
    :param batch_size: Batch size passed to ``Dataset.batch``.
    :return: A ``tf.data.Dataset`` of ``(features, labels)`` slices that is shuffled
             (buffer size 1000), repeated, and batched.
    """
    slices = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    shuffled = slices.shuffle(1000)
    repeated = shuffled.repeat()
    return repeated.batch(batch_size)


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation or prediction.

    :param features: Mapping-like collection of feature columns; converted with ``dict()``.
    :param labels: Labels paired with the features, or ``None`` to use only the features.
    :param batch_size: Batch size passed to ``Dataset.batch``; must not be ``None``.
    :return: A batched ``tf.data.Dataset`` built from the features (and labels, if given).
    """
    feature_dict = dict(features)
    # Without labels the dataset yields features only; otherwise (features, labels) pairs.
    inputs = feature_dict if labels is None else (feature_dict, labels)

    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    assert batch_size is not None, "batch_size must not be None"
    return dataset.batch(batch_size)


# The remainder of this file contains a simple example of a csv parser,
# implemented using the `Dataset` class.

# The CSV decoding op used in `_parse_line` sets the types of its outputs to
# match the example values given in its `record_defaults` argument. The entries
# below follow the order of CSV_COLUMN_NAMES: four float columns, then one
# integer column.
CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]


def _parse_line(line):
    """Decode one CSV text line into a ``(features, label)`` pair.

    :param line: A single line of CSV text.
    :return: A tuple ``(features, label)`` where ``features`` maps each name in
             ``CSV_COLUMN_NAMES`` except ``'Species'`` to its decoded field, and
             ``label`` is the decoded ``'Species'`` field.
    """
    # Decode the line into its fields. ``tf.io.decode_csv`` is used instead of the
    # top-level ``tf.decode_csv`` alias because the alias is not available in
    # TensorFlow 2.x, while the ``tf.io`` endpoint exists in both 1.x and 2.x.
    fields = tf.io.decode_csv(line, record_defaults=CSV_TYPES)

    # Pack the result into a dictionary
    features = dict(zip(CSV_COLUMN_NAMES, fields))

    # Separate the label from the features
    label = features.pop('Species')

    return features, label


def csv_input_fn(csv_path, batch_size):
    """Build a shuffled, repeated, batched dataset directly from a CSV file.

    :param csv_path: Path of the CSV file to read line by line.
    :param batch_size: Batch size passed to ``Dataset.batch``.
    :return: A ``tf.data.Dataset`` of parsed ``(features, label)`` examples.
    """
    # Read the file as text lines, dropping the first line (presumably the header row).
    lines = tf.data.TextLineDataset(csv_path).skip(1)

    # Turn every remaining line into a (features, label) pair.
    examples = lines.map(_parse_line)

    # Shuffle with a buffer of 1000, repeat, then batch.
    return examples.shuffle(1000).repeat().batch(batch_size)