Batch predictions for nearest neighbors #293
Conversation
@@ -5,498 +5,236 @@
# --------------------------------------------------------------------------
Why do we need to change the entire KNN converter implementation to support batch prediction?
I wanted to reuse the cdist optimisation I made for the GaussianProcess. It was faster that way.
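The cdist optimisation mentioned above can be sketched in plain numpy (names illustrative, not the converter's actual code): instead of computing distances one test sample at a time, the whole test batch is handled with one batched expression, which maps to a handful of ONNX nodes.

```python
import numpy as np

def cdist_sqeuclidean(XA, XB):
    """Squared euclidean distances between all rows of XA and XB in one
    batched expression, using ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2."""
    sqA = (XA ** 2).sum(axis=1)[:, None]   # shape (n, 1)
    sqB = (XB ** 2).sum(axis=1)[None, :]   # shape (1, m)
    return sqA - 2.0 * XA @ XB.T + sqB     # shape (n, m)

# sanity check against a per-sample loop
rng = np.random.RandomState(0)
XA, XB = rng.randn(5, 3), rng.randn(7, 3)
loop = np.array([[((a - b) ** 2).sum() for b in XB] for a in XA])
assert np.allclose(cdist_sqeuclidean(XA, XB), loop)
```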
if training_labels.dtype == np.int32:
    training_labels = training_labels.astype(np.int64)
extracted = OnnxArrayFeatureExtractor(
Is OnnxArrayFeatureExtractor different from ArrayFeatureExtractor? The reason the KNN converter couldn't handle batch prediction was ArrayFeatureExtractor: it can only handle one homogeneous set of indices (Y is a tensor of int), which means you can't extract different values for different test examples (based on their nearest neighbours).
That's the same. It is used to extract the neighbours' labels. It is equivalent to GatherElements when the target dimension is one, but it is still needed when the target dimension is more than one. The previous version had more than one ArrayFeatureExtractor.
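The per-row gather the two comments above are discussing can be illustrated in numpy (toy data, names illustrative): each test example has its own TopK indices, so a single shared index list cannot extract the right labels for a whole batch.

```python
import numpy as np

training_labels = np.array([10, 20, 30, 40, 50])
# TopK indices for a batch of 3 test examples with k=2: each row differs.
topk_idx = np.array([[0, 2],
                     [4, 1],
                     [3, 3]])

# Per-row gather (GatherElements-style): each row gets its own labels.
neighbor_labels = training_labels[topk_idx]
assert neighbor_labels.tolist() == [[10, 30], [50, 20], [40, 40]]

# One homogeneous index set can only extract the SAME positions for
# every test example, which is why batch prediction failed before:
shared = training_labels[np.array([0, 2])]
assert shared.tolist() == [10, 30]
```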
@@ -150,3 +153,33 @@ def _onnx_cdist_minkowski(X, Y, dtype=None, op_version=None, p=2, **kwargs):
                      op_version=op_version)
    return OnnxTranspose(node[1], perm=[1, 0], op_version=op_version,
                         **kwargs)


def _onnx_cdist_manhattan(X, Y, dtype=None, op_version=None, **kwargs):
Do we need to implement such a function for every metric we want to support?
That seemed easy to do. I refactored to avoid duplicated code.
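A minimal sketch of the refactoring pattern being described (hypothetical names, not the actual converter code): one small function per metric plus a dispatch table, so adding a metric does not duplicate the surrounding logic.

```python
import numpy as np

def _cdist_sqeuclidean(XA, XB):
    return ((XA[:, None, :] - XB[None, :, :]) ** 2).sum(axis=2)

def _cdist_manhattan(XA, XB):
    return np.abs(XA[:, None, :] - XB[None, :, :]).sum(axis=2)

# dispatch table: adding a metric means adding one entry here
_METRICS = {'sqeuclidean': _cdist_sqeuclidean,
            'cityblock': _cdist_manhattan}

def cdist(XA, XB, metric='sqeuclidean'):
    try:
        fct = _METRICS[metric]
    except KeyError:
        raise ValueError("Unsupported metric %r" % metric)
    return fct(np.asarray(XA, dtype=np.float64),
               np.asarray(XB, dtype=np.float64))
```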
def _onnx_cdist_manhattan(X, Y, dtype=None, op_version=None, **kwargs):
    """
    Returns the ONNX graph which computes the Minkowski distance
Minkowski distance?
Manhattan. Sorry, wrong copy-paste.
from ..common.data_types import Int64TensorType
from ..algebra.onnx_ops import (
    OnnxTopK, OnnxMul, OnnxArrayFeatureExtractor,
It's easier to locate these if sorted, given there are so many.
Done
def _get_weights(scope, container, topk_values_name, distance_power):
    Retrieves the nearest neigbours *ONNX*.
    :param X: features or *OnnxOperatorMixin*
    :param Y: neighbours or *OnnxOperatorMixin*
Y can be confusing here. Can you name them appropriately instead of X and Y?
Replaced by XA, XB
opv = container.target_opset
dtype = container.dtype

if X.type.__class__ == Int64TensorType:
Can't we use isinstance()?
Not all cdist return integer, it seems easier to cast now than after.
I meant can't you write:

    if isinstance(X, Int64TensorType)

instead of

    if X.type.__class__ == Int64TensorType
Done
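The "cast now rather than after" point above can be sketched in numpy (illustrative helper name, not the converter's actual API): distance computations produce floats for most metrics anyway, so casting integer inputs once, up front, keeps every downstream node on a single floating-point type.

```python
import numpy as np

def prepare_input(X):
    """Cast integer inputs to float before any distance computation,
    so no cast is needed after each distance node."""
    X = np.asarray(X)
    if np.issubdtype(X.dtype, np.integer):
        return X.astype(np.float32)
    return X

X_int = np.array([[1, 2], [3, 4]], dtype=np.int64)
assert prepare_input(X_int).dtype == np.float32
# floating-point inputs pass through unchanged
assert prepare_input(np.ones((2, 2), dtype=np.float32)).dtype == np.float32
```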
k = op.n_neighbors
training_labels = op._y if hasattr(op, '_y') else None
distance_kwargs = {}
if metric == 'minkowski':
Aren't we handling this with cdist?
I preferred to have a cdist function which follows the scipy implementation, and to let the converter choose a shorter path when it is more appropriate.
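A numpy sketch of the "shorter path" idea (illustrative code, not the converter itself): the generic minkowski formula reduces to manhattan for p=1 (and euclidean for p=2), so when the parameters allow it the converter can emit the cheaper specialised graph and still match scipy.

```python
import numpy as np

def cdist_minkowski(XA, XB, p=2):
    d = np.abs(XA[:, None, :] - XB[None, :, :])
    return (d ** p).sum(axis=2) ** (1.0 / p)

def cdist_manhattan(XA, XB):
    return np.abs(XA[:, None, :] - XB[None, :, :]).sum(axis=2)

rng = np.random.RandomState(0)
XA, XB = rng.randn(4, 3), rng.randn(5, 3)
# p=1 reduces to manhattan: same result, fewer operations
assert np.allclose(cdist_minkowski(XA, XB, p=1), cdist_manhattan(XA, XB))
```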
@@ -61,14 +61,14 @@ def _onnx_squareform_pdist_sqeuclidean(X, dtype=None, op_version=None,
    return node[1]


-def onnx_cdist(X, Y, metric='sqeuclidean', dtype=None,
+def onnx_cdist(XA, XB, metric='sqeuclidean', dtype=None,
Any reason for having the names in capitals? PEP 8 recommends variable names be in lower case.
I reused the same names scipy is using.
This pull request introduces 4 alerts when merging 47f0afa into 4b769c5 - view on LGTM.com new alerts:
I get an error with the int labels in the regressor:

    data = load_digits()
    model = KNeighborsRegressor().fit(X_train, y_train)
    model_onnx = convert_sklearn(
        model, 'knn',
        [('input', Int64TensorType([None, X_test.shape[1]]))])

    Fail Traceback (most recent call last)
    ~/Documents/MachineLearning/onnx_projects/tmp_env/lib/python3.6/site-packages/onnxruntime/capi/session.py in __init__(self, path_or_bytes, sess_options)
    ~/Documents/MachineLearning/onnx_projects/tmp_env/lib/python3.6/site-packages/onnxruntime/capi/session.py in _load_model(self, providers)
    Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from knn.onnx failed:
    Type Error: Type (tensor(float)) of output arg (variable) of node
    (Re_ReduceMean) does not match expected type (tensor(int64)).
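The type mismatch in that error is consistent with how the regressor averages neighbour labels: even when the training labels are int64, their mean is floating point, so a graph that declares the ReduceMean output as tensor(int64) cannot load. A minimal numpy sketch of the mismatch:

```python
import numpy as np

labels = np.array([3, 4], dtype=np.int64)  # int64 neighbour labels
pred = labels.mean()                       # averaging yields a float
assert pred.dtype == np.float64
assert pred == 3.5
# so the int64 labels need an explicit Cast to float before the
# ReduceMean, rather than declaring its output as tensor(int64)
```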
training_labels = training_labels.ravel()
axis = 1

if training_labels.dtype == np.int32:
Where do you handle string labels?
Apparently I did not, which also means it is not covered by any unit test. I'll need to add more tests tomorrow.
Yeah, we don't have unit tests for that. Also, if you are updating the unit tests, can you make them use fit_classification_model() and fit_regression_model() from tests_helper.py.
All fixed.
The last failure comes from the neighbours: when several neighbours are at the exact same distance, scikit-learn and onnx don't necessarily select the same ones.
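The tie problem above can be reproduced in a few lines of numpy (toy distances): with equal distances, argpartition (what scikit-learn uses) and a runtime's TopK may return different indices, even though the selected distance values are identical.

```python
import numpy as np

dist = np.array([0.5, 0.5, 0.1, 0.5])  # three neighbours tied at 0.5
k = 2

part = np.argpartition(dist, k - 1)[:k]      # scikit-learn style selection
sort = np.argsort(dist, kind='stable')[:k]   # a TopK-style selection

# The selected distances agree, but the chosen indices need not,
# so predictions can differ when ties fall on the k-th neighbour.
assert sorted(dist[part]) == sorted(dist[sort])
```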
Okay.
Also, I see mismatches with integer features in the KNN regressor:

    X, y = make_regression(n_samples=1000, n_features=100, random_state=42)
This pull request introduces 1 alert when merging 578e789 into dd9159c - view on LGTM.com new alerts:
if np.issubdtype(op.classes_.dtype, np.floating):
    classes = op.classes_.astype(np.int32)
elif np.issubdtype(op.classes_.dtype, np.signedinteger):
Both the elif and else branches are the same; you can just remove the elif part.
Converts *KNeighborsRegressor* into *ONNX*.
The converted model may return different predictions depending
on how the runtime select the topk element.
*sciki-learn* uses function `argpartition
scikit
training_labels = training_labels.ravel()
axis = 1

if training_labels.dtype == np.int32:
Okay.
StrictVersion(onnxruntime.__version__) < StrictVersion("0.5.0"),
reason="not available")
def test_model_knn_classifier_multi_class_string(self):
    model, X = self._fit_model_multiclass_classification(
We may want to clean up this unit test and use utility functions defined in test_utils.py later.
No description provided.