Add GaussianMixture #169
@@ -0,0 +1,147 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import numpy as np
from sklearn.mixture.gaussian_mixture import _compute_log_det_cholesky
from ..common._registration import register_converter
from ..algebra.onnx_ops import (
    OnnxAdd, OnnxSub, OnnxMul, OnnxGemm, OnnxReduceSumSquare,
    OnnxReduceLogSumExp, OnnxExp, OnnxArgMax, OnnxConcat
)

def convert_sklearn_gaussian_mixture(scope, operator, container):
    """
    Converter for *GaussianMixture*.
    Parameters which change the prediction function:

    * *covariance_type*
    """
    X = operator.inputs[0]
    out = operator.outputs
    op = operator.raw_operator
    n_features = X.type.shape[1]
    n_components = op.means_.shape[0]

    # All comments below come from scikit-learn code and tell
    # which functions are being onnxified.
    # def _estimate_weighted_log_prob(self, X):
    #     self._estimate_log_prob(X) + self._estimate_log_weights()
    log_weights = np.log(op.weights_)  # self._estimate_log_weights()

    # self._estimate_log_prob(X)

Reviewer: Commented code again?
Author: Same reason.

    log_det = _compute_log_det_cholesky(
        op.precisions_cholesky_, op.covariance_type, n_features)

    if op.covariance_type == 'full':
        # shape(op.means_) = (n_components, n_features)

Reviewer: I see a lot of commented code in this file, could you clean it up?
Author: I prefer to leave it; it is how it is implemented in scikit-learn. I can add a comment to specify it comes from sklearn.

        # shape(op.precisions_cholesky_) =
        #     (n_components, n_features, n_features)

        # log_prob = np.empty((n_samples, n_components))
        # for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
        #     y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
        #     log_prob[:, k] = np.sum(np.square(y), axis=1)

        ys = []
        for c in range(n_components):
            prec_chol = op.precisions_cholesky_[c, :, :]
            cst = - np.dot(op.means_[c, :], prec_chol)
            y = OnnxGemm(X, prec_chol, cst, alpha=1., beta=1.)
            y2s = OnnxReduceSumSquare(y, axes=[1])
            ys.append(y2s)
        log_prob = OnnxConcat(*ys, axis=1)

    elif op.covariance_type == 'tied':
        # shape(op.means_) = (n_components, n_features)
        # shape(op.precisions_cholesky_) =
        #     (n_features, n_features)

        # log_prob = np.empty((n_samples, n_components))
        # for k, mu in enumerate(means):
        #     y = np.dot(X, precisions_chol) - np.dot(mu, precisions_chol)
        #     log_prob[:, k] = np.sum(np.square(y), axis=1)

        precisions_chol = op.precisions_cholesky_
        ys = []
        for f in range(n_components):
            cst = - np.dot(op.means_[f, :], precisions_chol)
            y = OnnxGemm(X, precisions_chol, cst, alpha=1., beta=1.)
            y2s = OnnxReduceSumSquare(y, axes=[1])
            ys.append(y2s)
        log_prob = OnnxConcat(*ys, axis=1)

    elif op.covariance_type == 'diag':
        # shape(op.means_) = (n_components, n_features)
        # shape(op.precisions_cholesky_) =
        #     (n_components, n_features)

        # precisions = precisions_chol ** 2
        # log_prob = (np.sum((means ** 2 * precisions), 1) -
        #             2. * np.dot(X, (means * precisions).T) +
        #             np.dot(X ** 2, precisions.T))

        precisions = op.precisions_cholesky_ ** 2
        mp = np.sum((op.means_ ** 2 * precisions), 1)
        zeros = np.zeros((n_components, ))
        xmp = OnnxGemm(X, (op.means_ * precisions).T, zeros,
                       alpha=-2., beta=0.)
        term = OnnxGemm(OnnxMul(X, X), precisions.T, zeros, alpha=1., beta=0.)
        log_prob = OnnxAdd(OnnxAdd(mp, xmp), term)

    elif op.covariance_type == 'spherical':
        # shape(op.means_) = (n_components, n_features)
        # shape(op.precisions_cholesky_) = (n_components, )

        # precisions = precisions_chol ** 2
        # log_prob = (np.sum(means ** 2, 1) * precisions -
        #             2 * np.dot(X, means.T * precisions) +
        #             np.outer(row_norms(X, squared=True), precisions))

        zeros = np.zeros((n_components, ))
        precisions = op.precisions_cholesky_ ** 2
        normX = OnnxReduceSumSquare(X, axes=[1])
        outer = OnnxGemm(normX, precisions[np.newaxis, :], zeros,
                         alpha=1., beta=1.)
        xmp = OnnxGemm(X, (op.means_.T * precisions), zeros,
                       alpha=-2., beta=0.)
        mp = np.sum(op.means_ ** 2, 1) * precisions
        log_prob = OnnxAdd(mp, OnnxAdd(xmp, outer))
    else:
        raise RuntimeError("Unknown op.covariance_type='{}'. Upgrade "
                           "to a more recent version of sklearn-onnx "
                           "or raise an issue.".format(op.covariance_type))

    # -.5 * (cst + log_prob) + log_det
    cst = np.array([n_features * np.log(2 * np.pi)])
    add = OnnxAdd(cst, log_prob)
    mul = OnnxMul(add, np.array([-0.5]))
    if isinstance(log_det, float):
        log_det = np.array([log_det])
    weighted_log_prob = OnnxAdd(OnnxAdd(mul, log_det), log_weights)

    # labels
    labels = OnnxArgMax(weighted_log_prob, axis=1, output_names=out[:1])

    # def _estimate_log_prob_resp():
    #     np.exp(log_resp)
    #     weighted_log_prob = self._estimate_weighted_log_prob(X)
    #     log_prob_norm = logsumexp(weighted_log_prob, axis=1)
    #     with np.errstate(under='ignore'):
    #         log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]

    log_prob_norm = OnnxReduceLogSumExp(weighted_log_prob, axes=[1])
    log_resp = OnnxSub(weighted_log_prob, log_prob_norm)

    # probabilities
    probs = OnnxExp(log_resp, output_names=out[1:])

    # final
    labels.add_to(scope, container)
    probs.add_to(scope, container)


register_converter('SklearnGaussianMixture', convert_sklearn_gaussian_mixture)
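
For context, here is a minimal sketch (not part of this diff) of how the new converter could be exercised end to end. It assumes skl2onnx's public convert_sklearn entry point and an installed onnxruntime; the model name, the input name 'input' and the toy data are illustrative only.

import numpy as np
from sklearn.mixture import GaussianMixture
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as rt

# Fit a small mixture on random float data.
X = np.random.rand(100, 4).astype(np.float32)
gmm = GaussianMixture(n_components=3, covariance_type='full').fit(X)

# Convert; the graph exposes two outputs: labels and per-component probabilities.
onx = convert_sklearn(gmm, 'gaussian_mixture',
                      initial_types=[('input', FloatTensorType([100, 4]))])

sess = rt.InferenceSession(onx.SerializeToString())
labels, probs = sess.run(None, {'input': X})

# Expected to agree with scikit-learn up to float32 precision.
assert (labels.ravel() == gmm.predict(X)).all()
assert np.allclose(probs, gmm.predict_proba(X), atol=1e-4)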
@@ -0,0 +1,30 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ..common._registration import register_shape_calculator
from ..common.data_types import FloatTensorType, Int64TensorType
from ..common.utils import (
    check_input_and_output_numbers,
    check_input_and_output_types
)


def calculate_gaussian_mixture_output_shapes(operator):
    check_input_and_output_numbers(operator, input_count_range=1,
                                   output_count_range=2)
    check_input_and_output_types(operator, good_input_types=[FloatTensorType])

Reviewer: Why is int not allowed as an input type? Scikit allows int features.
Author: I hesitate. Statistically, it makes no sense to fit a gaussian mixture on integer data as it cannot be gaussian. I'll fix it.
(See the sketch after this file for a possible workaround with integer features.)


    if len(operator.inputs[0].type.shape) != 2:
        raise RuntimeError('Input must be a [N, C]-tensor')

    op = operator.raw_operator
    N = operator.inputs[0].type.shape[0]
    operator.outputs[0].type = Int64TensorType([N, 1])
    operator.outputs[1].type = FloatTensorType([N, op.n_components])


register_shape_calculator('SklearnGaussianMixture',
                          calculate_gaussian_mixture_output_shapes)
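
Following up on the review question above about integer inputs: since this shape calculator only accepts FloatTensorType, a user with integer features would currently have to cast them before conversion. A hedged sketch of that user-side workaround (variable names are illustrative, not part of this PR):

import numpy as np
from sklearn.mixture import GaussianMixture
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Integer features are accepted by scikit-learn itself ...
X_int = np.random.randint(0, 10, size=(50, 3))
gmm = GaussianMixture(n_components=2).fit(X_int)

# ... but the graph declared here only takes floats, so cast once
# before conversion and again at prediction time.
X_float = X_int.astype(np.float32)
onx = convert_sklearn(gmm, 'gaussian_mixture',
                      initial_types=[('input', FloatTensorType([50, 3]))])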
Reviewer: Why commented code?
Author: To remember where I found the implementation in scikit-learn.
Reviewer: I think it would be better to have comments instead. That would make it clear to anyone reading the code.
Author: I think I did (line 29). Does it need more?