Expose ONNX_ML build option to python (onnx#2138)
* Expose ONNX_ML build option to python

* Add type
bddppq committed Jun 28, 2019
1 parent 36c0d7d commit 556a5c3
Showing 4 changed files with 14 additions and 7 deletions.
1 change: 1 addition & 0 deletions onnx/__init__.py
@@ -5,6 +5,7 @@
 
 import os
 
+from .onnx_cpp2py_export import ONNX_ML
 from onnx.external_data_helper import load_external_data_for_model, write_external_data_tensors
 from .onnx_pb import *  # noqa
 from .onnx_operators_pb import *  # noqa
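With this re-export in place, user code can read the flag straight from the top-level package. A minimal usage sketch (the value simply mirrors whatever ONNX_ML setting the installed build was compiled with):

import onnx

# True when the installed onnx was built with ONNX_ML=1 (ai.onnx.ml schemas included).
if onnx.ONNX_ML:
    print("ONNX_ML build: ai.onnx.ml operator schemas are available")
else:
    print("ONNX_ML=0 build: ai.onnx.ml operator schemas are not registered")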
8 changes: 8 additions & 0 deletions onnx/cpp2py_export.cc
@@ -19,6 +19,14 @@ using namespace pybind11::literals;
 PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
   onnx_cpp2py_export.doc() = "Python interface to onnx";
 
+  onnx_cpp2py_export.attr("ONNX_ML") = py::bool_(
+#ifdef ONNX_ML
+      true
+#else // ONNX_ML
+      false
+#endif // ONNX_ML
+  );
+
   // Submodule `schema`
   auto defs = onnx_cpp2py_export.def_submodule("defs");
   defs.doc() = "Schema submodule";
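The flag is attached to the compiled pybind11 extension itself, so it is also visible without the package-level re-export. A small sketch, assuming the extension imports under its usual name:

import onnx.onnx_cpp2py_export as _C

# The pybind11 module exposes the compile-time flag as a plain Python bool.
assert isinstance(_C.ONNX_ML, bool)
print("ONNX_ML build flag:", _C.ONNX_ML)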
2 changes: 2 additions & 0 deletions onnx/onnx_cpp2py_export/__init__.pyi
@@ -1,3 +1,5 @@
 # This is __init__.pyi, not __init__.py
 # This directory is only considered for typing information, not for actual
 # module implementations.
+
+ONNX_ML : bool = ...
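Because the stub declares the attribute as bool, static type checkers can follow code that branches on it. A hypothetical snippet (not part of this commit) written in the repository's comment-annotation style:

from onnx.onnx_cpp2py_export import ONNX_ML

def ml_enabled():  # type: () -> bool
    # ONNX_ML is declared as bool in the stub, so mypy accepts returning it here.
    return ONNX_ML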
10 changes: 3 additions & 7 deletions onnx/test/shape_inference_test.py
@@ -3,7 +3,7 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from onnx import checker, helper, TensorProto, NodeProto, GraphProto, ValueInfoProto, ModelProto
+from onnx import checker, helper, TensorProto, NodeProto, GraphProto, ValueInfoProto, ModelProto, ONNX_ML
 from onnx.helper import make_node, make_tensor, make_tensor_value_info, make_empty_tensor_value_info, make_opsetid
 from typing import Sequence, Union, Text, Tuple, List, Any, Optional
 import onnx.shape_inference
@@ -1943,9 +1943,7 @@ def test_tile_rank_inference(self):  # type: () -> None
         self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (None, None, None))])  # type: ignore
 
     def test_linearclassifier_1D_input(self):  # type: () -> None
-        onnx_ml = os.environ.get('ONNX_ML')  # type: ignore
-        # No environment variable set (None) indicates ONNX_ML=1
-        if(onnx_ml is None or int(onnx_ml) != 0):
+        if ONNX_ML:
             graph = self._make_graph(
                 [('x', TensorProto.FLOAT, (5,))],
                 [make_node('LinearClassifier', ['x'], ['y', 'z'], domain='ai.onnx.ml', coefficients=[0.0008, -0.0008], intercepts=[2.0, 2.0], classlabels_ints=[1, 2])],
@@ -1955,9 +1953,7 @@ def test_linearclassifier_1D_input(self):  # type: () -> None
                 opset_imports=[make_opsetid('ai.onnx.ml', 1), make_opsetid('', 11)])
 
     def test_linearclassifier_2D_input(self):  # type: () -> None
-        onnx_ml = os.environ.get('ONNX_ML')  # type: ignore
-        # No environment variable set (None) indicates ONNX_ML=1
-        if(onnx_ml is None or int(onnx_ml) != 0):
+        if ONNX_ML:
             graph = self._make_graph(
                 [('x', TensorProto.FLOAT, (4, 5))],
                 [make_node('LinearClassifier', ['x'], ['y', 'z'], domain='ai.onnx.ml', coefficients=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], intercepts=[2.0, 2.0, 3.0], classlabels_ints=[1, 2, 3])],
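The updated tests simply fall through and do nothing when the flag is off. A hypothetical alternative (not what this commit does) would be to skip such cases explicitly with unittest, so ML-only tests are reported as skipped rather than silently passing:

import unittest

from onnx import ONNX_ML

class HypotheticalMLTests(unittest.TestCase):
    @unittest.skipUnless(ONNX_ML, "requires an ONNX build compiled with ONNX_ML=1")
    def test_linearclassifier_shape_inference(self):  # type: () -> None
        # ML-domain shape inference checks would go here.
        pass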
