Skip to content

Commit

Permalink
Adding onnxruntime_provider branch (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
bammari committed Sep 24, 2023
1 parent b60bf0d commit 47f646b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/neuralnet/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def test_onnx_relu(datadir):
def obj(mdl):
return 1

net_regression = ort.InferenceSession(datadir.file("keras_linear_131_relu.onnx"))
net_regression = ort.InferenceSession(
datadir.file("keras_linear_131_relu.onnx"), providers=["CPUExecutionProvider"]
)

for x in [-0.25, 0.0, 0.25, 1.5]:
model.nn.inputs.fix(x)
Expand Down Expand Up @@ -93,7 +95,9 @@ def test_onnx_linear(datadir):
def obj(mdl):
return 1

net_regression = ort.InferenceSession(datadir.file("keras_linear_131.onnx"))
net_regression = ort.InferenceSession(
datadir.file("keras_linear_131.onnx"), providers=["CPUExecutionProvider"]
)

for x in [-0.25, 0.0, 0.25, 1.5]:
model.nn.inputs.fix(x)
Expand Down Expand Up @@ -134,7 +138,10 @@ def test_onnx_sigmoid(datadir):
def obj(mdl):
return 1

net_regression = ort.InferenceSession(datadir.file("keras_linear_131_sigmoid.onnx"))
net_regression = ort.InferenceSession(
datadir.file("keras_linear_131_sigmoid.onnx"),
providers=["CPUExecutionProvider"],
)

for x in [-0.25, 0.0, 0.25, 1.5]:
model.nn.inputs.fix(x)
Expand Down

0 comments on commit 47f646b

Please sign in to comment.