-
Notifications
You must be signed in to change notification settings - Fork 123
Closed
Description
Hi,
I trained an SVM classifier (`SVC`) with GridSearchCV and then converted it to ONNX, but the resulting model is not valid: the nodes of the model's graph are not topologically sorted.
I suspect the problem lies in the conversion of GridSearchCV itself, so it is probably independent of the type of classifier used.
A possible workaround is to pass the best classifier found by GridSearchCV (its `best_estimator_`) to the `to_onnx()` function instead of the GridSearchCV object itself.
The script below reproduces both the issue and the workaround described above.
Best regards,
Biagio.
import numpy as np
import onnx
from sklearn import svm
from sklearn.model_selection import GridSearchCV
from sklearn import datasets
from sklearn.model_selection import train_test_split
from skl2onnx.helpers.onnx_helper import load_onnx_model
from skl2onnx import to_onnx
# Fixed seed for reproducibility: used as ``random_state`` for the train/test
# split in load_train_test, and also used to seed NumPy's global random state.
rand_seed = 0
np.random.seed(rand_seed)
def convert_to_onnx(sklearn_model, X, model_savename):
    """Convert a fitted scikit-learn model to ONNX, validate it, and save it.

    The converted graph is checked with ``onnx.checker.check_model`` and the
    verdict is printed. The model is written to ``<model_savename>.onnx``
    whether or not it passes the check, so an invalid graph can still be
    inspected.

    Args:
        sklearn_model: A fitted estimator (or meta-estimator such as
            ``GridSearchCV``) passed to ``skl2onnx.to_onnx``.
        X: NumPy array of sample inputs. Only its first row (cast to
            float32) is used, to let ``to_onnx`` infer the input signature.
        model_savename: Output path without the ``.onnx`` extension.

    Returns:
        bool: True if the ONNX checker accepted the model, False otherwise.
    """
    onnx_model_filename = model_savename + '.onnx'
    # One row is enough: to_onnx only needs it to infer input name/type/shape.
    onnx_model = to_onnx(sklearn_model, X[:1].astype(np.float32))

    # Check the model; report the outcome instead of only printing it.
    try:
        onnx.checker.check_model(onnx_model)
    except onnx.checker.ValidationError as e:
        print(f'The model is NOT valid:\n{e}')
        is_valid = False
    else:
        print('The model is valid!')
        is_valid = True

    # Saved even when invalid, so the broken graph can be examined offline.
    with open(onnx_model_filename, "wb") as f:
        f.write(onnx_model.SerializeToString())
    return is_valid
def load_train_test():
    """Load the Iris dataset and split it 80/20 into train and test sets.

    The split is reproducible because it is seeded with the module-level
    ``rand_seed``.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)``.
    """
    dataset = datasets.load_iris()
    splits = train_test_split(
        dataset.data,
        dataset.target,
        train_size=0.8,
        random_state=rand_seed,
    )
    # train_test_split returns a list; keep returning a 4-tuple as before.
    return tuple(splits)
def train_svc_gs(X_train, y_train, apply_fix=False):
    """Grid-search an RBF SVC over ``C`` and ``gamma`` and return the result.

    Args:
        X_train: Training features.
        y_train: Training labels.
        apply_fix: If truthy, return the search's ``best_estimator_`` (a plain
            fitted SVC) rather than the fitted ``GridSearchCV`` object.
            Converting the plain SVC is the workaround for the invalid ONNX
            graph produced when the GridSearchCV wrapper is converted.

    Returns:
        The fitted ``GridSearchCV`` object, or its ``best_estimator_`` when
        ``apply_fix`` is truthy.
    """
    base_estimator = svm.SVC(
        kernel='rbf',
        coef0=0.0,
        degree=3,
        decision_function_shape='ovr',
        probability=True,
    )
    search = GridSearchCV(
        base_estimator,
        {
            'C': [0.1, 1, 1e1],
            'gamma': [1e-3, 1e-2, 1e-1],
        },
    )
    search.fit(X_train, y_train)
    return search.best_estimator_ if apply_fix else search
def train_svc(X_train, y_train):
    """Fit and return a single RBF SVC with fixed hyperparameters (no search).

    Uses ``C=10`` and ``gamma=3e-2``, with probability estimates enabled.

    Args:
        X_train: Training features.
        y_train: Training labels.

    Returns:
        The fitted ``SVC``.
    """
    model = svm.SVC(
        kernel='rbf',
        coef0=0.0,
        C=10,
        gamma=3e-2,
        degree=3,
        decision_function_shape='ovr',
        probability=True,
    )
    # SVC.fit returns the estimator itself.
    return model.fit(X_train, y_train)
def run(test_gs=True, apply_fix=False):
    """Train an SVC on Iris, convert it to ONNX, and validate/save the result.

    Args:
        test_gs: If True, train via ``GridSearchCV``; otherwise train a single
            SVC with fixed hyperparameters.
        apply_fix: Only used when ``test_gs`` is True. If True, convert the
            grid search's ``best_estimator_`` instead of the GridSearchCV
            object (the workaround).

    Raises:
        ValueError: If ``test_gs`` or ``apply_fix`` does not compare equal to
            True or False.
    """
    # Validate before doing any work, and with a real exception: ``assert``
    # statements are stripped when Python runs with -O.
    if apply_fix not in (True, False):
        raise ValueError("Invalid value for apply_fix")
    if test_gs not in (True, False):
        raise ValueError("Invalid value for test_gs")

    # Load train and test dataset (test labels are not needed here).
    X_train, X_test, y_train, _ = load_train_test()

    if test_gs:
        clf = train_svc_gs(X_train, y_train, apply_fix)
        # Output name reflects the expected outcome of each variant.
        onnx_model_name = "svc_gs_valid" if apply_fix else "svc_gs_not_valid"
    else:
        clf = train_svc(X_train, y_train)
        onnx_model_name = "svc"

    # X_test only supplies a sample row to fix the ONNX input signature.
    convert_to_onnx(clf, X_test, onnx_model_name)
if __name__ == "__main__":
    # Each scenario: (header printed before the run, keyword args for run()).
    # Executed in order: GridSearchCV without workaround, with workaround,
    # and a plain SVC without GridSearchCV for comparison.
    scenarios = (
        ("SVC model trained with GridSearchCV (without workaround):",
         {"test_gs": True, "apply_fix": False}),
        ("\nSVC model trained with GridSearchCV (with workaround):",
         {"test_gs": True, "apply_fix": True}),
        ("\nSVC model trained without GridSearchCV:",
         {"test_gs": False}),
    )
    for header, run_kwargs in scenarios:
        print(header)
        run(**run_kwargs)
Metadata
Metadata
Assignees
Labels
No labels