
Handle PyTorch models #198 #264

Merged: 19 commits merged into mlflow:master on Aug 17, 2018

Conversation

@vfdev-5 (Contributor) commented Aug 8, 2018

Addresses #198

Here is a basic version of how to handle PyTorch models, similar to what is done for h2o, etc.
Note that pyfunc's predict function works on numpy arrays, not pandas DataFrames.

@aarondav what do you think?

Thanks
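
For context, here is a minimal sketch of the intended save/load round trip. The function names follow the pattern of the existing flavor modules (e.g. mlflow.h2o); the exact signatures in this PR may differ.

import torch
import mlflow.pytorch
import mlflow.pyfunc

model = torch.nn.Linear(4, 1)  # any torch.nn.Module

# Save the model as an MLflow model directory
mlflow.pytorch.save_model(model, "my_pytorch_model")

# Reload it either as a native PyTorch model ...
reloaded = mlflow.pytorch.load_model("my_pytorch_model")

# ... or through the generic pyfunc interface
pyfunc_model = mlflow.pyfunc.load_pyfunc("my_pytorch_model")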

.travis.yml Outdated
@@ -39,6 +39,8 @@ install:
# Install protoc
- wget https://github.com/google/protobuf/releases/download/v3.6.0/protoc-3.6.0-linux-x86_64.zip -O /travis-install/protoc.zip
- sudo unzip /travis-install/protoc.zip -d /usr
# Install pytorch cpu only (no wheels on pip)
- conda install -c pytorch pytorch-cpu
Contributor:

FYI, it would be better to add these to tox-requirements.txt so people can also run the tests locally. Does that work if you do it, or do you need to get PyTorch through Conda?

Contributor Author:

There are no CPU-only torch wheels on PyPI, so if we put torch in the tox requirements we would be asking people to download ~500 MB.

Contributor Author:

However, it is possible to download a specific CPU-only build from the PyTorch download server, e.g.
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
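
Such a wheel URL can be installed directly (pip accepts direct links to wheel files), either in a requirements file or as:

pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl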

@vfdev-5 force-pushed the pytorch_save_model branch 2 times, most recently from 707ee09 to f4c4c18, on August 8, 2018 at 07:45
@codecov-io commented Aug 8, 2018

Codecov Report

Merging #264 into master will increase coverage by 0.4%.
The diff coverage is 94.91%.


@@            Coverage Diff            @@
##           master     #264     +/-   ##
=========================================
+ Coverage   56.02%   56.43%   +0.4%     
=========================================
  Files         112      113      +1     
  Lines        5560     5619     +59     
=========================================
+ Hits         3115     3171     +56     
- Misses       2445     2448      +3
Impacted Files Coverage Δ
mlflow/pytorch.py 94.91% <94.91%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 7bd4c51...4f9d2ed.

self.pytorch_model = pytorch_model

def predict(self, data, device='cpu'):
assert isinstance(data, np.ndarray), "Input data should be numpy.ndarray"
Collaborator:

Does this support inputting a pandas.DataFrame? For reference.

Contributor Author:

@juntai-zheng no, it does not support pandas.DataFrame.
How could we put an image into a DataFrame?

Contributor:

This is a good point, but currently implementors of pyfunc do expect that they can hand predict() a DataFrame and get a DataFrame out.

We should have a follow-up discussion on extending the predict interface to support n-dimensional data. Until then, perhaps we could have two pathways: one where data is an np.ndarray, and another where the input is a pandas DataFrame. In the latter case, we can directly convert it to an ndarray and also convert the result back to a pandas DataFrame. This does assume the input is 2D for now, which is pretty constraining. A sketch of the dual pathway appears below.
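
A minimal sketch of that dual pathway, reusing the ndarray-based prediction logic already in this PR (the function below is illustrative, not code from the diff):

import numpy as np
import pandas as pd
import torch

def predict(pytorch_model, data, device='cpu'):
    # Accept either a 2-D ndarray or a DataFrame, and mirror the input type in the output.
    is_df = isinstance(data, pd.DataFrame)
    array = data.values if is_df else data
    with torch.no_grad():
        input_tensor = torch.from_numpy(array.astype(np.float32)).to(device)
        preds = pytorch_model(input_tensor).to('cpu').numpy()
    return pd.DataFrame(preds) if is_df else preds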

Contributor Author (@vfdev-5, Aug 10, 2018):

@aarondav I just wonder: don't the predict functions for other frameworks (sklearn, h2o, TF, etc.) also convert the DataFrame to a numpy array to do the processing?
What are the advantages of using a DataFrame over a 2D numpy array?

Contributor:

Right, good question. The idea of taking a DataFrame is to integrate with other frameworks. For example, you can load a pyfunc into a Spark UDF:

Return a Spark UDF that can be used to invoke the Python function formatted model.
Parameters passed to the UDF are forwarded to the model as a DataFrame where the names are
simply ordinals (0, 1, ...).
Example:
.. code:: python
predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
df.withColumn("prediction", predict("name", "age")).show()

In order for this to work, we use the Pandas Vectorized UDF support in Spark to efficiently convert Spark DataFrames to Pandas DataFrames, and then pass them into the predict() function. If the predict() function cannot take Pandas DataFrames, then it simply cannot be used as a Spark UDF.

return x.view(x.shape[0], -1)


class TestModelExport(unittest.TestCase):
Collaborator:

We're currently working on moving our testing suites to pytest. Do you think you could refactor your tests to make use of pytest? You can see an example at test_databricks.py.

Contributor Author:

@juntai-zheng yes, I can refactor that.

@vfdev-5 (Contributor Author) commented Aug 14, 2018

Tests refactored to pytest.
Adapted the test code to be similar to the Keras tests; predictions are now done on pandas DataFrames.

@smurching smurching self-requested a review August 15, 2018 02:39
@smurching (Collaborator) left a comment:

Thanks @vfdev-5, this looks pretty good - just had a few questions, namely one about handling multi-output models.


# Loading pyfunc model
pyfunc_loaded = mlflow.pyfunc.load_pyfunc(path)
assert np.all(pyfunc_loaded.predict(x).values[:, 0] == predicted)
Collaborator:

Q: Just to confirm, is the [:, 0] indexing necessary because the PyFunc model generates a single output column containing a N-dimensional array with the same shape as the PyTorch model's output tensor?

Contributor Author:

Yes, [:, 0] is necessary because the output of pyfunc_loaded.predict(x) is a pandas DataFrame of shape (N, 1), while predicted has shape (N,). Wrapping the result in a DataFrame gives this behavior; see the illustration below.
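
A standalone illustration of that wrapping behavior (not code from the PR):

import numpy as np
import pandas as pd

preds = np.array([0.1, 0.9, 0.4])             # model output, shape (3,)
wrapped = pd.DataFrame(preds)                 # single-column DataFrame, shape (3, 1)
assert wrapped.shape == (3, 1)
assert np.all(wrapped.values[:, 0] == preds)  # [:, 0] recovers the 1-D array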


# This may be replaced by a warning and a try/except around torch.load
flavor = mlflow_model.flavors[FLAVOR_NAME]
assert torch.__version__ == flavor["pytorch_version"], \
Collaborator:

Nit: Instead of an assert, could we raise an Exception here? Asserts can be disabled by the user. Same with the FLAVOR_NAME check above

Contributor Author:

@smurching raise ValueError or RuntimeError?

Collaborator:

Good question, IMO a ValueError makes sense here - in general our exception handling/raising isn't super consistent across the different model implementations (we often just raise a bare Exception); it's something we should eventually address. For example, the version check could look like the sketch below.
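
A minimal sketch of the assert rewritten as a raise (the message wording is illustrative; flavor and FLAVOR_NAME come from the surrounding load code):

import torch

if torch.__version__ != flavor["pytorch_version"]:
    raise ValueError(
        "Stored model was saved with PyTorch %s but PyTorch %s is installed"
        % (flavor["pytorch_version"], torch.__version__))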

self.pytorch_model.eval()
with torch.no_grad():
input_tensor = torch.from_numpy(data.values.astype(np.float32)).to(device)
preds = self.pytorch_model(input_tensor)
Collaborator:

Q: do you know if self.pytorch_model(input_tensor) always outputs a single tensor? I was trying to look at the docs for torch.nn.Module.__call__ (see here), but the return type of __call__ doesn't seem to be documented anywhere.

Collaborator:

Looking at this example, it seems that a model might be able to return multiple values. Maybe for now we can raise an exception in that case & think about how to support multi-output models in the future. Alternatively, we could try putting each output tensor into its own DataFrame column (although we'd need to be careful about documenting the ordering of the output columns). Either way, we should document the behavior in our RST docs

Contributor Author (@vfdev-5, Aug 15, 2018):

Actually, there is no restriction on the number of tensors in the input and output of the model. This should also be true for other frameworks like TF and Keras.
I think such a check should be done at initialization time rather than at prediction time; however, in PyTorch such introspection can be rather difficult (maybe we could just check with a random (N, M) tensor).
I think it would be better to mention in the documentation that models used with pyfunc should take a single input tensor of shape (N, M) and produce a single output tensor of shape (N, K).

Collaborator:

Got it, yeah mentioning the restriction on model input/output (one input tensor & one output tensor) in the documentation makes sense to me. IMO it would still be nice to raise an exception here if preds isn't a tensor, because otherwise (e.g. if preds is a tuple) the user will see an error like "tuple has no attribute numpy()", which doesn't give them an idea of how to fix the error.

How about something like:

if not isinstance(preds, torch.Tensor):
    raise RuntimeError("Expected PyTorch model to output a single output tensor, but got output of type %s" % type(preds))

@smurching (Collaborator) left a comment:

Responding to some questions, thanks for the continued work on this @vfdev-5 :)

mlflow_model = Model.load(mlflow_model_path)

assert FLAVOR_NAME in mlflow_model.flavors, \
"Stored model can not be loaded with mlflow.pytorch"
Collaborator:

Could we update this error message to mention the flavors that are present? e.g. "Could not find flavor %s amongst available flavors %s, unable to load stored model" % (FLAVOR_NAME, mlflow_model.flavors). A sketch follows.
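
A minimal sketch of that check as a raise, using the message suggested above (FLAVOR_NAME and mlflow_model come from the surrounding load code):

if FLAVOR_NAME not in mlflow_model.flavors:
    raise ValueError(
        "Could not find flavor %s amongst available flavors %s, unable to load stored model"
        % (FLAVOR_NAME, list(mlflow_model.flavors)))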



@vfdev-5 (Contributor Author) commented Aug 16, 2018

@smurching changed the asserts into if -> raise *Error, added a check on the preds type, and updated the error messages. Added tests for the raised exceptions.

Let me know if I forgot something you mentioned.

@vfdev-5 (Contributor Author) commented Aug 16, 2018

@smurching could you please restart the tests?

@mateiz (Contributor) commented Aug 16, 2018

I just restarted it -- hope this helps.

@vfdev-5 (Contributor Author) commented Aug 16, 2018

Honestly, I do not understand why the CI fails randomly: first with Python 2.7, now with Python 3.6...

@smurching (Collaborator):

@vfdev-5 sorry about that, I'll take a look at the build today. For now I'm going to try adding some debug options to the Travis build to print more output; hope you don't mind.

@vfdev-5 (Contributor Author) commented Aug 16, 2018

@smurching no problem, but I have a feeling it is related to the Travis infrastructure plus the load from running lots of different tests.

@vfdev-5 (Contributor Author) commented Aug 17, 2018

@smurching thanks for working on the PR! The code and docs are starting to look better now :)

raise TypeError("Input data should be pandas.DataFrame")
self.pytorch_model.eval()
with torch.no_grad():
input_tensor = torch.from_numpy(data.values.astype(np.float32)).to(device)
Contributor Author (@vfdev-5, Aug 17, 2018):

I think I have a small correction to add here. We can have an inconsistency between device and the **kwargs of load_pyfunc(path, **kwargs): namely, if the model is loaded on GPU and the user then sets device to CPU (or vice versa), the operation preds = self.pytorch_model(input_tensor) will fail.

Fix: force the model onto the same device as the input tensor; see the sketch below.
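
A minimal sketch of that fix, matching the self.pytorch_model.to(device) call visible in the final diff further down:

# Inside _PyTorchWrapper.predict, before evaluating the model:
self.pytorch_model.to(device)  # ensure the model lives on the same device as the input tensor
self.pytorch_model.eval()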
@smurching (Collaborator) commented Aug 17, 2018

No problem, yeah just wanted to tune up the docstrings - this otherwise looked pretty good. Would be great if you could address the issue you discussed with the device argument, glad you caught that. Maybe we just shouldn't have a device argument to the PyFunc's predict method?

Looks like I also introduced a linter error (mlflow/pytorch.py:100:88: W291 trailing whitespace) - if you're able to fix it that'd be great, otherwise happy to make the fix (sorry for the bug in the first place).

There was one other design question I wanted to bring up, I'll comment in the code.

Thanks again for working on this :)

@smurching (Collaborator) left a comment:

Had one design question about a potential breaking change we may need to make in the future, otherwise this is looking great - thanks @vfdev-5 :)

"""
Log a PyTorch model as an MLflow artifact for the current run.

:param pytorch_model: PyTorch model to be saved.
Collaborator:

Could we document that pytorch_model must accept a single input tensor & produce a single output tensor here? We could say something like:

:param pytorch_model: PyTorch model to be saved. Must accept a single input tensor and produce a single output tensor.

"""
Save a PyTorch model to a path on the local file system.

:param pytorch_model: PyTorch model to be saved.
Collaborator:

Same comment as above, could we document the limitation on input/output tensors here?

class _PyTorchWrapper(object):
"""
Wrapper class that creates a predict function such that
predict(data: ndarray) -> model's output as numpy.ndarray
Collaborator:

Could we update this docstring since our input / output are now pd.DataFrames instead of numpy.ndarrays?
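
A possible updated docstring (illustrative wording only):

class _PyTorchWrapper(object):
    """
    Wrapper class that creates a predict function such that
    predict(data: pd.DataFrame) -> model's output as pd.DataFrame
    """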

self.pytorch_model.to(device)
self.pytorch_model.eval()
with torch.no_grad():
input_tensor = torch.from_numpy(data.values.astype(np.float32)).to(device)
Collaborator:

Sorry I'm noticing this late, wanted to get your opinion: currently users can only predict on models with a single input tensor, which precludes using MLflow with models that take input tensors of different types (e.g. when predicting on images, maybe we'd have a float tensor containing pixel values & integer values corresponding to image height/width). IMO this restriction is fine as long as we can extend this implementation to support multi-input models in the future.

My only concern with the current implementation is that we convert the entire DataFrame to a single tensor and pass that as the sole input to the model, which requires us to make breaking changes down the line if we want to instead pass each column as an input tensor to the model. Maybe we should instead require that the input DataFrame contains a single column of numpy arrays and convert that column to an input tensor (e.g. via np.hstack(data[data.columns[0]]))? A sketch of that convention follows.
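
A standalone sketch of the proposed single-column convention (illustrative only; np.vstack is used here to stack the per-row arrays into an (N, M) batch, which is presumably the intent of the hstack suggestion above):

import numpy as np
import pandas as pd
import torch

# Hypothetical convention: one DataFrame column whose cells are 1-D numpy arrays.
df = pd.DataFrame({"features": [np.random.rand(4).astype(np.float32) for _ in range(8)]})

batch = np.vstack(df[df.columns[0]].tolist())  # stack per-row arrays into shape (8, 4)
input_tensor = torch.from_numpy(batch)
assert input_tensor.shape == (8, 4)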

Collaborator:

Discussed offline: let's keep this issue in mind but move forward with the current implementation for now :)

@vfdev-5 (Contributor Author) commented Aug 17, 2018

Would be great if you could address the issue you discussed with the device argument, glad you caught that.

Fixed that here.

Maybe we just shouldn't have a device argument to the PyFunc's predict method?

If we keep it, the user also gets the option of executing the prediction on the GPU.

@smurching (Collaborator):

Tests passed & this LGTM - thanks @vfdev-5, merging :)!

@smurching smurching merged commit b00826f into mlflow:master Aug 17, 2018
@vfdev-5 (Contributor Author) commented Aug 17, 2018

@smurching glad to help make this tool better!

jdlesage pushed a commit to jdlesage/mlflow that referenced this pull request Nov 9, 2020