Skip to content

Commit

Permalink
fix: support Slice operator in older ONNX op sets (#104)
Browse files Browse the repository at this point in the history
* fix Slice.from_onnx

* fixed outputs in Slice testcase

---------

Co-authored-by: M Baumann <marcelbaumann16@gmail.com>
Co-authored-by: David Shriver <davidshriver@outlook.com>
  • Loading branch information
3 people committed Feb 5, 2023
1 parent d39aeec commit f71c983
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
3 changes: 2 additions & 1 deletion dnnv/nn/operations/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ def __init__(

@classmethod
def from_onnx(cls, onnx_node, *inputs):
return cls(*inputs, name=onnx_node.name)
attributes = {a.name: as_numpy(a) for a in onnx_node.attribute}
return cls(*inputs, **attributes, name=onnx_node.name)


class Tile(Operation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ def test_from_onnx_Elu_unimplemented():
assert str(excinfo.value) == "Unimplemented operation type: Fake"


def test_from_onnx_Slice():
input_op = Input(np.array([-1, 5]), np.dtype(np.float32))

add_node = onnx.helper.make_node(
"Slice", inputs=["input"], outputs=["slice"], name="slice", starts=[0], ends=[1]
)
op_from_onnx = Operation.from_onnx(add_node, input_op)

assert type(op_from_onnx) is Slice
assert op_from_onnx.x is input_op
assert op_from_onnx.starts == [0]
assert op_from_onnx.ends == [1]


def test_inputs():
input_op_0 = Input((1, 4), np.dtype(np.float32))
add_op = Add(input_op_0, np.ones((1, 4), dtype=np.float32))
Expand Down

0 comments on commit f71c983

Please sign in to comment.