-
Notifications
You must be signed in to change notification settings - Fork 96
/
label_encoder.py
29 lines (23 loc) · 1.15 KB
/
label_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import numpy as np
from ..common._registration import register_converter
def convert_sklearn_label_encoder(scope, operator, container):
op = operator.raw_operator
op_type = 'LabelEncoder'
attrs = {'name': scope.get_unique_operator_name(op_type)}
classes = op.classes_
if np.issubdtype(classes.dtype, np.floating):
attrs['keys_floats'] = classes
elif np.issubdtype(classes.dtype, np.signedinteger):
attrs['keys_int64s'] = classes
else:
attrs['keys_strings'] = np.array([s.encode('utf-8') for s in classes])
attrs['values_int64s'] = np.arange(len(classes))
container.add_node(op_type, operator.input_full_names,
operator.output_full_names, op_domain='ai.onnx.ml',
op_version=2, **attrs)
register_converter('SklearnLabelEncoder', convert_sklearn_label_encoder)