Skip to content

Commit

Permalink
Supports cosine distance (LocalOutlierFactor, ...) (#1050)
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Dec 11, 2023
1 parent 78933db commit ae29a33
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.16.0

* Supports cosine distance (LocalOutlierFactor, ...)
[#1050](https://github.com/onnx/sklearn-onnx/pull/1050),
* Add an example on how to handle FunctionTransformer
[#1042](https://github.com/onnx/sklearn-onnx/pull/1042),
Versions of `scikit-learn < 1.0` are not tested any more.
Expand Down
74 changes: 57 additions & 17 deletions skl2onnx/algebra/complex_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
from ..common.data_types import FloatTensorType, DoubleTensorType
from ..common.utils import get_unique_subgraph
from .onnx_ops import (
OnnxAbs,
OnnxDiv,
OnnxIdentity,
OnnxMatMul,
OnnxPow,
OnnxScan,
OnnxTranspose,
OnnxSub,
OnnxReduceSumSquareApi18,
OnnxSqrt,
OnnxPow,
OnnxAbs,
OnnxSub,
OnnxReduceSumApi11,
OnnxReduceSumSquareApi18,
OnnxTranspose,
)

logger = getLogger("skl2onnx")
Expand Down Expand Up @@ -70,7 +72,7 @@ def _onnx_squareform_pdist_sqeuclidean(X, dtype=None, op_version=None, **kwargs)
num_scan_inputs=1,
body=(scan_body.graph, [id_next, flat]),
op_version=op_version,
**kwargs
**kwargs,
)
logger.debug("[_onnx_squareform_pdist_sqeuclidean] +Scan dtype=%r", dtype)
return node[1]
Expand All @@ -84,7 +86,7 @@ def onnx_cdist(
op_version=None,
dim_in=None,
dim_out=None,
**kwargs
**kwargs,
):
"""
Returns the ONNX graph which computes
Expand All @@ -110,14 +112,14 @@ def onnx_cdist(
op_version=op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs
**kwargs,
)
elif metric == "euclidean":
if metric == "euclidean":
res = _onnx_cdist_sqeuclidean(
XA, XB, dtype=dtype, op_version=op_version, dim_in=dim_in, dim_out=dim_out
)
return OnnxSqrt(res, op_version=op_version, **kwargs)
elif metric == "minkowski":
if metric == "minkowski":
p = kwargs.pop("p")
res = _onnx_cdist_minkowski(
XA,
Expand All @@ -131,18 +133,27 @@ def onnx_cdist(
return OnnxPow(
res, np.array([1.0 / p], dtype=dtype), op_version=op_version, **kwargs
)
elif metric in ("manhattan", "cityblock"):
if metric in ("manhattan", "cityblock"):
return _onnx_cdist_manhattan(
XA,
XB,
dtype=dtype,
op_version=op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs
**kwargs,
)
else:
raise NotImplementedError("metric='{}' is not implemented.".format(metric))
if metric == "cosine":
return _onnx_cdist_cosine(
XA,
XB,
dtype=dtype,
op_version=op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs,
)
raise NotImplementedError(f"metric={metric!r} is not implemented.")


def _onnx_cdist_begin(op_version):
Expand Down Expand Up @@ -204,7 +215,7 @@ def _onnx_cdist_sqeuclidean(
op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -233,7 +244,7 @@ def _onnx_cdist_minkowski(
op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs
**kwargs,
)


Expand All @@ -258,5 +269,34 @@ def _onnx_cdist_manhattan(
op_version,
dim_in=dim_in,
dim_out=dim_out,
**kwargs
**kwargs,
)


def _onnx_cdist_cosine(
    XA, XB, dtype=None, op_version=None, dim_in=None, dim_out=None, **kwargs
):
    """
    Builds the ONNX graph which computes
    ``cdist(XA, XB, metric='cosine')``: one minus the dot product
    of every row of *XA* with every row of *XB*, divided by the
    product of their Euclidean norms.

    *dtype* is the type of the constant ``1`` the ratio is subtracted
    from. *dim_in* and *dim_out* are accepted but not used by this
    function. Remaining *kwargs* are forwarded to the final
    ``OnnxSub`` node.
    """
    # XB is transposed once and reused for both the dot products and its norms.
    xb_t = OnnxTranspose(XB, perm=[1, 0], op_version=op_version)
    dots = OnnxMatMul(XA, xb_t, op_version=op_version)

    # Norms of XA reduced over axis 1 with keepdims=1.
    sq_a = OnnxReduceSumSquareApi18(XA, axes=[1], keepdims=1, op_version=op_version)
    norm_a = OnnxSqrt(sq_a, op_version=op_version)

    # Norms of the transposed XB reduced over axis 0 with keepdims=1.
    sq_b = OnnxReduceSumSquareApi18(
        xb_t, axes=[0], keepdims=1, op_version=op_version
    )
    norm_b = OnnxSqrt(sq_b, op_version=op_version)

    # The matrix product of both norm tensors gives every pairwise norm product.
    norm_products = OnnxMatMul(norm_a, norm_b, op_version=op_version)
    similarity = OnnxDiv(dots, norm_products, op_version=op_version)

    one = np.array([1], dtype=dtype)
    return OnnxSub(one, similarity, op_version=op_version, **kwargs)
35 changes: 34 additions & 1 deletion tests/test_algebra_onnx_operators_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,39 @@ def test_onnx_example_cdist_in_custom_ops(self):
exp = scipy_cdist(x * 2, x, metric="sqeuclidean")
assert_almost_equal(exp, res[0], decimal=4)

@unittest.skipIf(TARGET_OPSET < 10, reason="not available")
@unittest.skipIf(
    pv.Version(ort_version) <= pv.Version(THRESHOLD2),
    reason="fails with onnxruntime 0.4.0",
)
@ignore_warnings(category=DeprecationWarning)
def test_onnx_example_cdist_in_cosine(self):
    """
    Compares the graph built by ``onnx_cdist(..., metric="cosine")``
    with ``scipy_cdist`` on a small example whose second matrix
    ends with a null row.
    """
    left = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
    right_values = [1.1, 2.1, 4.01, 5.01, 5.001, 4.001, 0, 0]
    right = np.array(right_values).astype(np.float32).reshape((4, 2))
    opset = _TARGET_OPSET_

    # The graph doubles its input before measuring distances to `right`.
    doubled = OnnxAdd("input", "input", op_version=opset)
    distances = onnx_cdist(
        doubled, right, dtype=np.float32, metric="cosine", op_version=opset
    )
    final = OnnxIdentity(distances, output_names=["cdist"], op_version=opset)

    onx = final.to_onnx(
        inputs=[("input", FloatTensorType([None, None]))],
        outputs=[("cdist", FloatTensorType())],
    )

    session = InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    got = session.run(None, {"input": left})

    # Reference: the input is doubled by hand to mirror the Add node.
    expected = scipy_cdist(left * 2, right, metric="cosine")
    assert_almost_equal(expected, got[0], decimal=5)


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
27 changes: 26 additions & 1 deletion tests/test_sklearn_local_outlier_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,31 @@ def test_local_outlier_factor_rnd(self):
assert_almost_equal(expected_label, got[0].ravel())
assert_almost_equal(expected_decif, got[1].ravel(), decimal=5)

@unittest.skipIf(LocalOutlierFactor is None, reason="old scikit-learn")
def test_local_outlier_factor_cosine(self):
    """
    Converts a ``LocalOutlierFactor`` fitted with ``metric="cosine"``
    and compares labels and scores returned by onnxruntime with
    ``predict`` and ``decision_function`` on slightly shifted data.
    """
    lof = LocalOutlierFactor(n_neighbors=2, novelty=True, metric="cosine")
    train = np.array(
        [[-1.1, -1.2], [0.3, 0.2], [0.5, 0.4], [100.0, 99.0]], dtype=np.float32
    )
    fitted = lof.fit(train)
    onx = to_onnx(fitted, train, target_opset=TARGET_OPSET)
    # The serialized model text must not mention "CDist".
    self.assertNotIn("CDist", str(onx))

    # Predictions are checked on a copy whose first column is shifted.
    shifted = train.copy()
    shifted[:, 0] += 0.1

    session = InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    output_names = [out.name for out in session.get_outputs()]
    self.assertEqual(output_names, ["label", "scores"])

    results = session.run(None, {"X": shifted})
    self.assertEqual(len(results), 2)

    expected_label = lof.predict(shifted)
    expected_decif = lof.decision_function(shifted)
    assert_almost_equal(expected_label, results[0].ravel())
    assert_almost_equal(expected_decif, results[1].ravel(), decimal=4)


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)

0 comments on commit ae29a33

Please sign in to comment.